diff --git a/.coveragerc b/.coveragerc index d84a6fc8..c79b5374 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,10 @@ include = omit = kazoo/tests/* kazoo/testing/* + +# Note - this is a copy of the default exclusions from coverage 7.10.1 +[report] +exclude_lines = + #\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER) + ^\s*(((async )?def .*?)?\)(\s*->.*?)?:\s*)?\.\.\.\s*(#|$) + if (typing\.)?TYPE_CHECKING: diff --git a/.flake8 b/.flake8 index ba8a3d67..839f0322 100644 --- a/.flake8 +++ b/.flake8 @@ -3,12 +3,18 @@ builtins = _ exclude = .git, __pycache__, - .venv/,venv/, + .venv/, + venv*/, .tox/, - build/,dist/,*egg, + build/, + dist/, + *egg, docs/conf.py, - zookeeper/ + zookeeper/, # See black's documentation for E203 max-line-length = 79 -extend-ignore = BLK100,E203 +# I am not sure what version of flake8 hound is using +# but it gives a lot of undefined names for comments (F821) +# and redefinition of unused variables (F811) +extend-ignore = BLK100,E203,F811,F821 diff --git a/.gitignore b/.gitignore index 1c2b4b24..cc13a345 100644 --- a/.gitignore +++ b/.gitignore @@ -29,14 +29,15 @@ zookeeper/ .idea .project .pydevproject -.tox +.tox*/ venv*/ /.settings /.metadata +__pycache__/ !.gitignore !.git-blame-ignore-revs -.vscode/settings.json +.vscode/ .*_cache/ coverage.xml diff --git a/kazoo/client.py b/kazoo/client.py index 3029d1c5..b7a881ea 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -1,4 +1,7 @@ """Kazoo Zookeeper Client""" + +from __future__ import annotations + from collections import defaultdict, deque from functools import partial import inspect @@ -6,6 +9,21 @@ from os.path import split import re import warnings +from types import TracebackType +from typing import ( + cast, + overload, + Any, + Callable, + Deque, + Literal, + Optional, + Sequence, + Set, + TypedDict, + TYPE_CHECKING, +) +from typing_extensions import ParamSpec from kazoo.exceptions import ( AuthFailedError, @@ -63,6 +81,10 @@ from kazoo.recipe.queue import Queue, 
LockingQueue from kazoo.recipe.watchers import ChildrenWatch, DataWatch +if TYPE_CHECKING: + from kazoo.interfaces import Event, IAsyncResult, IHandler + from kazoo.protocol.states import ZnodeStat + CLOSED_STATES = ( KeeperState.EXPIRED_SESSION, @@ -88,6 +110,27 @@ retry_max_delay="max_delay", ) +# Signature for functions called by add_listener +ListenerFunc = Callable[[KazooState], Optional[bool]] + +# Signatures for get, get_children and exists watches +WatchFunc = Callable[[WatchedEvent], Optional[bool]] + +GenericArgs = ParamSpec("GenericArgs") + + +# Kazoo retry parameters +class KazooRetryParams(TypedDict, total=False): + max_tries: int + delay: float + backoff: int + max_jitter: float + max_delay: float + ignore_expire: bool + sleep_func: Callable[[float], None] + deadline: float + interrupt: Callable[[], bool] + class KazooClient(object): """An Apache Zookeeper Python client supporting alternate callback @@ -100,29 +143,29 @@ class KazooClient(object): """ - def __init__( + def __init__( # type: ignore[misc] self, - hosts="127.0.0.1:2181", - timeout=10.0, - client_id=None, - handler=None, - default_acl=None, - auth_data=None, - sasl_options=None, - read_only=None, - randomize_hosts=True, - connection_retry=None, - command_retry=None, - logger=None, - keyfile=None, - keyfile_password=None, - certfile=None, - ca=None, - use_ssl=False, - verify_certs=True, - check_hostname=False, - **kwargs, - ): + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: set[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + 
keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Any, + ) -> None: """Create a :class:`KazooClient` instance. All time arguments are in seconds. @@ -234,8 +277,15 @@ def __init__( self.auth_data = auth_data if auth_data else set([]) self.default_acl = default_acl self.randomize_hosts = randomize_hosts - self.hosts = None - self.chroot = None + # Note: hosts and chroot are set by set_hosts, which also checks for + # chroot changes at runtime, so we initialize them to None here to + # avoid confusion with the empty string that set_hosts would set them + # to. This is massively hacky as set_hosts is only called from here + # anyway, but I want to make this change minimally invasive. + # we should really do self.hosts, self.chroot = self.set_hosts(hosts) + # and have set_hosts return the hosts and chroot + self.hosts: list[tuple[str, int]] = None # type: ignore[assignment] + self.chroot: str = None # type: ignore[assignment] self.set_hosts(hosts) self.use_ssl = use_ssl @@ -247,11 +297,15 @@ def __init__( self.ca = ca # Curator like simplified state tracking, and listeners for # state transitions - self._state = KeeperState.CLOSED - self.state = KazooState.LOST - self.state_listeners = set() - self._child_watchers = defaultdict(set) - self._data_watchers = defaultdict(set) + self._state: KeeperState = KeeperState.CLOSED + self.state: KazooState = KazooState.LOST + self.state_listeners: set[ListenerFunc] = set() + self._child_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) + self._data_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) self._reset() self.read_only = read_only @@ -272,7 +326,14 @@ def __init__( self._stopped.set() self._writer_stopped.set() - self.retry = self._conn_retry = None + # This is kind of gross but we need to set these to something so that + # the type checker will 
understand that they are set by the time they + # are used and that they have the right type. + # We would do better to use a few variables/functions instead of + # overloading self.retry but this is a bit less invasive to the code + # and the type checker can understand it with a few hacks + self.retry: KazooRetry = None # type: ignore[assignment] + self._conn_retry: KazooRetry = None # type: ignore[assignment] if type(connection_retry) is dict: self._conn_retry = KazooRetry(**connection_retry) @@ -299,7 +360,11 @@ def __init__( ) if self.retry is None or self._conn_retry is None: - old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS) + # Note: because of the hacks at line 280, mypy thinks this is + # unreachable + old_retry_keys = dict( # type: ignore[unreachable] + _RETRY_COMPAT_DEFAULTS + ) for key in old_retry_keys: try: old_retry_keys[key] = kwargs.pop(key) @@ -320,11 +385,13 @@ def __init__( if self._conn_retry is None: self._conn_retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) if self.retry is None: self.retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) # Managing legacy SASL options @@ -372,10 +439,20 @@ def __init__( # to avoid shared retry counts self._retry = self.retry - def _retry(*args, **kwargs): - return self._retry.copy()(*args, **kwargs) - - self.retry = _retry + def _retry( + func: Callable[GenericArgs, KazooRetry.RETRY_RETURN], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> KazooRetry.RETRY_RETURN: + return self._retry.copy()(func, *args, **kwargs) + + # (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", + # variable has type "KazooRetry") so basically self.retry needs to be + # set to that and then the type checker will understand that + # self.retry.copy() is a valid call. 
This is just a mess and needs the + # code rearranging to be more mypy friendly but this is the least + # invasive way to do it for now + self.retry = _retry # type: ignore[assignment] self.Barrier = partial(Barrier, self) self.Counter = partial(Counter, self) @@ -402,18 +479,18 @@ def _retry(*args, **kwargs): % (kwargs.keys(),) ) - def _reset(self): + def _reset(self) -> None: """Resets a variety of client states for a new connection.""" - self._queue = deque() - self._pending = deque() + self._queue: Deque[tuple[Any, IAsyncResult]] = deque() + self._pending: Deque[tuple[Any, IAsyncResult, int]] = deque() self._reset_watchers() self._reset_session() self.last_zxid = 0 - self._protocol_version = None + self._protocol_version: int | None = None - def _reset_watchers(self): - watchers = [] + def _reset_watchers(self) -> None: + watchers: list[WatchFunc] = [] for child_watchers in self._child_watchers.values(): watchers.extend(child_watchers) @@ -427,12 +504,12 @@ def _reset_watchers(self): for watch in watchers: self.handler.dispatch_callback(Callback("watch", watch, (ev,))) - def _reset_session(self): + def _reset_session(self) -> None: self._session_id = None self._session_passwd = b"\x00" * 16 @property - def client_state(self): + def client_state(self) -> KeeperState: """Returns the last Zookeeper client state This is the non-simplified state information and is generally @@ -442,7 +519,7 @@ def client_state(self): return self._state @property - def client_id(self): + def client_id(self) -> tuple[int | None, bytes] | None: """Returns the client id for this Zookeeper session if connected. 
@@ -455,12 +532,16 @@ def client_id(self): return None @property - def connected(self): + def connected(self) -> bool: """Returns whether the Zookeeper connection has been established.""" return self._live.is_set() - def set_hosts(self, hosts, randomize_hosts=None): + def set_hosts( + self, + hosts: str | list[str], + randomize_hosts: bool | None = None, + ) -> None: """sets the list of hosts used by this client. This function accepts the same format hosts parameter as the init @@ -504,7 +585,7 @@ def set_hosts(self, hosts, randomize_hosts=None): self.chroot = new_chroot - def add_listener(self, listener): + def add_listener(self, listener: ListenerFunc) -> None: """Add a function to be called for connection state changes. This function will be called with a @@ -519,15 +600,20 @@ def add_listener(self, listener): should be used so that the listener can return immediately. """ - if not (listener and callable(listener)): + # This check should be unnecessary but protects against people who are + # not using type checkers and accidentally passing in something that + # isn't callable. It should be removed. 
+ if not ( + listener and callable(listener) # type: ignore[truthy-function] + ): raise ConfigurationError("listener must be callable") self.state_listeners.add(listener) - def remove_listener(self, listener): + def remove_listener(self, listener: ListenerFunc) -> None: """Remove a listener function""" self.state_listeners.discard(listener) - def _make_state_change(self, state): + def _make_state_change(self, state: KazooState) -> None: # skip if state is current if self.state == state: return @@ -544,7 +630,7 @@ def _make_state_change(self, state): except Exception: self.logger.exception("Error in connection state listener") - def _session_callback(self, state): + def _session_callback(self, state: KeeperState) -> None: if state == self._state: return @@ -581,9 +667,10 @@ def _session_callback(self, state): self._make_state_change(KazooState.SUSPENDED) self._reset_watchers() - def _notify_pending(self, state): + def _notify_pending(self, state: KeeperState) -> None: """Used to clear a pending response queue and request queue during connection drops.""" + exc: KazooException if state == KeeperState.AUTH_FAILED: exc = AuthFailedError() elif state == KeeperState.EXPIRED_SESSION: @@ -607,7 +694,7 @@ def _notify_pending(self, state): except IndexError: break - def _safe_close(self): + def _safe_close(self) -> None: self.handler.stop() timeout = self._session_timeout // 1000 if timeout < 10: @@ -618,7 +705,9 @@ def _safe_close(self): "and wouldn't close after %s seconds" % timeout ) - def _call(self, request, async_object): + def _call( + self, request: object, async_object: IAsyncResult + ) -> bool | None: """Ensure the client is in CONNECTED or SUSPENDED state and put the request in the queue if it is. 
@@ -647,14 +736,16 @@ def _call(self, request, async_object): async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None try: write_sock.send(b"\0") except: # NOQA async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None - def start(self, timeout=15): + def start(self, timeout: float = 15.0) -> None: """Initiate connection to ZK. :param timeout: Time in seconds to wait for connection to @@ -678,7 +769,7 @@ def _call(self, request, async_object): "should be created before normal use." ) - def start_async(self): + def start_async(self) -> Event: """Asynchronously initiate connection to ZK. :returns: An event object that can be checked to see if the @@ -705,7 +796,7 @@ def start_async(self): self._connection.start() return self._live - def stop(self): + def stop(self) -> None: """Gracefully stop this Zookeeper session. This method can be called while a reconnection attempt is in @@ -721,18 +812,22 @@ def stop(self): return self._stopped.set() - self._queue.append((CloseInstance, None)) + self._queue.append((CloseInstance, cast("IAsyncResult", None))) try: - self._connection._write_sock.send(b"\0") + # This assert should never fail since the connection should + # have been started but I'm not sure how to persuade mypy of that + self._connection._write_sock.send( # type: ignore[union-attr] + b"\0" + ) finally: self._safe_close() - def restart(self): + def restart(self) -> None: """Stop and restart the Zookeeper session.""" self.stop() self.start() - def close(self): + def close(self) -> None: """Free any resources held by the client. This method should be called on a stopped client before it is @@ -742,7 +837,7 @@ def close(self): """ self._connection.close() - def command(self, cmd=b"ruok"): + def command(self, cmd: bytes = b"ruok") -> str: """Sent a management command to the current ZK server. Examples are `ruok`, `envi` or `stat`.
@@ -761,8 +856,18 @@ def command(self, cmd=b"ruok"): if not self._live.is_set(): raise ConnectionLoss("No connection to server") - peer = self._connection._socket.getpeername()[:2] - peer_host = self._connection._socket.getpeername()[1] + # Need a way of persuading mypy that the connection is live and thus + # the socket is not None + peer = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + :2 + ] + ) + peer_host = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + 1 + ] + ) sock = self.handler.create_connection( peer, hostname=peer_host, @@ -780,7 +885,7 @@ def command(self, cmd=b"ruok"): sock.close() return result.decode("utf-8", "replace") - def server_version(self, retries=3): + def server_version(self, retries: int = 3) -> tuple[int, ...]: """Get the version of the currently connected ZK server. :returns: The server version, for example (3, 4, 3). @@ -790,7 +895,7 @@ def server_version(self, retries=3): """ - def _try_fetch(): + def _try_fetch() -> tuple[int, ...] | None: data = self.command(b"envi") data_parsed = {} for line in data.splitlines(): @@ -804,13 +909,19 @@ def _try_fetch(): if k: data_parsed[k] = v version = data_parsed.get(ENVI_VERSION_KEY, "") - version_digits = ENVI_VERSION.match(version).group(1) + # FIXME If you get an unexpected answer, you'll crash - not + # changing the code, so just ignoring the type error + version_digits = ENVI_VERSION.match( + version + ).group( # type: ignore[union-attr] + 1 + ) try: return tuple([int(d) for d in version_digits.split(".")]) except ValueError: return None - def _is_valid(version): + def _is_valid(version: tuple[int, ...] | None) -> bool: # All zookeeper versions should have at least major.minor # version numbers; if we get one that doesn't it is likely not # correct and was truncated...
@@ -818,21 +929,29 @@ def _is_valid(version): return True return False + # FIXME A better way of doing this would be to put the initial + # _try_fetch in the loop and inline _is_valid but I want to minimise + # code changes + # Try 1 + retries amount of times to get a version that we know # will likely be acceptable... version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + # and the next 2 suppress should include return-value + # but hound is broken + return version # type: ignore for _i in range(0, retries): version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + return version # type: ignore raise KazooException( "Unable to fetch useable server" " version after trying %s times" % (1 + max(0, retries)) ) - def add_auth(self, scheme, credential): + def add_auth(self, scheme: str, credential: str) -> bool: """Send credentials to server. :param scheme: authentication scheme (default supported: @@ -847,9 +966,9 @@ def add_auth(self, scheme, credential): the session state will be set to AUTH_FAILED as well. """ - return self.add_auth_async(scheme, credential).get() + return cast("bool", self.add_auth_async(scheme, credential).get()) - def add_auth_async(self, scheme, credential): + def add_auth_async(self, scheme: str, credential: str) -> IAsyncResult: """Asynchronously send credentials to server. Takes the same arguments as :meth:`add_auth`. @@ -868,7 +987,7 @@ def add_auth_async(self, scheme, credential): self._call(Auth(0, scheme, credential), async_result) return async_result - def unchroot(self, path): + def unchroot(self, path: str) -> str: """Strip the chroot if applicable from the path.""" if not self.chroot: return path @@ -879,7 +998,7 @@ def unchroot(self, path): else: return path - def sync_async(self, path): + def sync_async(self, path: str) -> IAsyncResult: """Asynchronous sync. 
:rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -888,10 +1007,10 @@ def sync_async(self, path): async_result = self.handler.async_result() @wrap(async_result) - def _sync_completion(result): + def _sync_completion(result: IAsyncResult) -> str: return self.unchroot(result.get()) - def _do_sync(): + def _do_sync() -> None: result = self.handler.async_result() self._call(Sync(_prefix_root(self.chroot, path)), result) result.rawlink(_sync_completion) @@ -899,7 +1018,7 @@ def _do_sync(): _do_sync() return async_result - def sync(self, path): + def sync(self, path: str) -> str: """Sync, blocks until response is acknowledged. Flushes channel between process and leader. @@ -913,18 +1032,44 @@ def sync(self, path): .. versionadded:: 0.5 """ - return self.sync_async(path).get() + return cast("str", self.sync_async(path).get()) + @overload + def create( + self, + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[False] = False, + ) -> str: + ... + + @overload def create( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[True] = True, + ) -> tuple[str, ZnodeStat]: + ... + + def create( + self, + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> str | tuple[str, ZnodeStat]: """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1003,26 +1148,29 @@ def create( The `include_data` option. 
""" acl = acl or self.default_acl - return self.create_async( - path, - value, - acl=acl, - ephemeral=ephemeral, - sequence=sequence, - makepath=makepath, - include_data=include_data, - ).get() + return cast( + "str | tuple[str, ZnodeStat]", + self.create_async( + path, + value, + acl=acl, + ephemeral=ephemeral, + sequence=sequence, + makepath=makepath, + include_data=include_data, + ).get(), + ) def create_async( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously create a ZNode. Takes the same arguments as :meth:`create`. @@ -1066,11 +1214,16 @@ def create_async( async_result = self.handler.async_result() @capture_exceptions(async_result) - def do_create(): + def do_create() -> None: result = self._create_async_inner( path, value, - acl, + # The way acl is constructed ends up confusing mypy, which + # thinks that acl can be None here, even though the code + # above ensures that if acl is None, it gets set to + # OPEN_ACL_UNSAFE, so we ignore the type error here. + # behaves differently in python3.8 and python3.14, sigh. 
+ acl, # type: ignore[arg-type] flags, trailing=sequence, include_data=include_data, @@ -1078,12 +1231,14 @@ def do_create(): result.rawlink(create_completion) @capture_exceptions(async_result) - def retry_completion(result): + def retry_completion(result: IAsyncResult) -> None: result.get() do_create() @wrap(async_result) - def create_completion(result): + def create_completion( + result: IAsyncResult, + ) -> str | tuple[str, ZnodeStat] | None: try: if include_data: new_path, stat = result.get() @@ -1098,18 +1253,22 @@ def create_completion(result): else: parent, _ = split(path) self.ensure_path_async(parent, acl).rawlink(retry_completion) + return None do_create() return async_result def _create_async_inner( - self, path, value, acl, flags, trailing=False, include_data=False - ): + self, + path: str, + value: bytes, + acl: Sequence[ACL], + flags: int, + trailing: bool = False, + include_data: bool = False, + ) -> IAsyncResult: async_result = self.handler.async_result() - if include_data: - opcode = Create2 - else: - opcode = Create + opcode = Create2 if include_data else Create call_result = self._call( opcode( @@ -1126,19 +1285,24 @@ def _create_async_inner( # exception upwards to the do_create function in # KazooClient.create so that it gets set on the correct # async_result object - raise async_result.exception + # Note: Do we actually need call_result? It seems like we could + # just check the state of the exception, and avoid the typing + # stuff. + raise async_result.exception # type: ignore[misc] return async_result - def ensure_path(self, path, acl=None): + def ensure_path(self, path: str, acl: Sequence[ACL] | None = None) -> bool: """Recursively create a path if it doesn't exist. :param path: Path of node. :param acl: Permissions for node. 
""" - return self.ensure_path_async(path, acl).get() + return cast("bool", self.ensure_path_async(path, acl).get()) - def ensure_path_async(self, path, acl=None): + def ensure_path_async( + self, path: str, acl: Sequence[ACL] | None = None + ) -> IAsyncResult: """Recursively create a path asynchronously if it doesn't exist. Takes the same arguments as :meth:`ensure_path`. @@ -1151,19 +1315,21 @@ def ensure_path_async(self, path, acl=None): async_result = self.handler.async_result() @wrap(async_result) - def create_completion(result): + def create_completion(result: IAsyncResult) -> bool: try: - return result.get() + return cast("bool", result.get()) except NodeExistsError: return True @capture_exceptions(async_result) - def prepare_completion(next_path, result): + def prepare_completion(next_path: str, result: IAsyncResult) -> None: result.get() self.create_async(next_path, acl=acl).rawlink(create_completion) @wrap(async_result) - def exists_completion(path, result): + def exists_completion( + path: str, result: IAsyncResult + ) -> Literal[True] | None: if result.get(): return True parent, node = split(path) @@ -1173,12 +1339,15 @@ def exists_completion(path, result): ) else: self.create_async(path, acl=acl).rawlink(create_completion) + return None self.exists_async(path).rawlink(partial(exists_completion, path)) return async_result - def exists(self, path, watch=None): + def exists( + self, path: str, watch: WatchFunc | None = None + ) -> ZnodeStat | None: """Check if a node exists. If a watch is provided, it will be left on the node with the @@ -1198,9 +1367,13 @@ def exists(self, path, watch=None): returns a non-zero error code. """ - return self.exists_async(path, watch=watch).get() + return cast( + "ZnodeStat | None", self.exists_async(path, watch=watch).get() + ) - def exists_async(self, path, watch=None): + def exists_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously check if a node exists. 
Takes the same arguments as :meth:`exists`. @@ -1218,7 +1391,9 @@ def exists_async(self, path, watch=None): ) return async_result - def get(self, path, watch=None): + def get( + self, path: str, watch: WatchFunc | None = None + ) -> tuple[bytes, ZnodeStat]: """Get the value of a node. If a watch is provided, it will be left on the node with the @@ -1241,9 +1416,13 @@ def get(self, path, watch=None): returns a non-zero error code """ - return self.get_async(path, watch=watch).get() + return cast( + "tuple[bytes, ZnodeStat]", self.get_async(path, watch=watch).get() + ) - def get_async(self, path, watch=None): + def get_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously get the value of a node. Takes the same arguments as :meth:`get`. @@ -1261,7 +1440,12 @@ def get_async(self, path, watch=None): ) return async_result - def get_children(self, path, watch=None, include_data=False): + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> list[str]: """Get a list of child nodes of a path. If a watch is provided it will be left on the node with the @@ -1295,11 +1479,19 @@ def get_children(self, path, watch=None, include_data=False): The `include_data` option. """ - return self.get_children_async( - path, watch=watch, include_data=include_data - ).get() + return cast( + "list[str]", + self.get_children_async( + path, watch=watch, include_data=include_data + ).get(), + ) - def get_children_async(self, path, watch=None, include_data=False): + def get_children_async( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously get a list of child nodes of a path. Takes the same arguments as :meth:`get_children`. 
@@ -1314,6 +1506,8 @@ def get_children_async(self, path, watch=None, include_data=False): raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() + # FIXME? Do this as req = getc2 if include_data else getc + req: GetChildren | GetChildren2 if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1321,7 +1515,7 @@ def get_children_async(self, path, watch=None, include_data=False): self._call(req, async_result) return async_result - def get_acls(self, path): + def get_acls(self, path: str) -> tuple[list[ACL], ZnodeStat]: """Return the ACL and stat of the node of the given path. :param path: Path of the node. @@ -1339,9 +1533,11 @@ def get_acls(self, path): .. versionadded:: 0.5 """ - return self.get_acls_async(path).get() + return cast( + "tuple[list[ACL], ZnodeStat]", self.get_acls_async(path).get() + ) - def get_acls_async(self, path): + def get_acls_async(self, path: str) -> IAsyncResult: """Return the ACL and stat of the node of the given path. Takes the same arguments as :meth:`get_acls`. @@ -1355,7 +1551,9 @@ def get_acls_async(self, path): self._call(GetACL(_prefix_root(self.chroot, path)), async_result) return async_result - def set_acls(self, path, acls, version=-1): + def set_acls( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> ZnodeStat: """Set the ACL for the node of the given path. Set the ACL for the node of the given path if such a node @@ -1382,9 +1580,13 @@ def set_acls(self, path, acls, version=-1): .. versionadded:: 0.5 """ - return self.set_acls_async(path, acls, version).get() + return cast( + "ZnodeStat", self.set_acls_async(path, acls, version).get() + ) - def set_acls_async(self, path, acls, version=-1): + def set_acls_async( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> IAsyncResult: """Set the ACL for the node of the given path. Takes the same arguments as :meth:`set_acls`. 
@@ -1407,7 +1609,9 @@ def set_acls_async(self, path, acls, version=-1): ) return async_result - def set(self, path, value, version=-1): + def set( + self, path: str, value: bytes | None, version: int = -1 + ) -> ZnodeStat: """Set the value of a node. If the version of the node being updated is newer than the @@ -1440,9 +1644,11 @@ def set(self, path, value, version=-1): returns a non-zero error code. """ - return self.set_async(path, value, version).get() + return cast("ZnodeStat", self.set_async(path, value, version).get()) - def set_async(self, path, value, version=-1): + def set_async( + self, path: str, value: bytes | None, version: int = -1 + ) -> IAsyncResult: """Set the value of a node. Takes the same arguments as :meth:`set`. @@ -1463,7 +1669,7 @@ def set_async(self, path, value, version=-1): ) return async_result - def transaction(self): + def transaction(self) -> TransactionRequest: """Create and return a :class:`TransactionRequest` object Creates a :class:`TransactionRequest` object. A Transaction can @@ -1480,7 +1686,12 @@ def transaction(self): """ return TransactionRequest(self) - def delete(self, path, version=-1, recursive=False): + # This should not return anything. No return value is documented, and the + # two called functions return different things. AFAICT. But for now I + # want to minimise code changes + def delete( + self, path: str, version: int = -1, recursive: bool = False + ) -> Any: """Delete a node. The call will succeed if such a node exists, and the given @@ -1518,7 +1729,7 @@ def delete(self, path, version=-1, recursive=False): else: return self.delete_async(path, version).get() - def delete_async(self, path, version=-1): + def delete_async(self, path: str, version: int = -1) -> IAsyncResult: """Asynchronously delete a node. Takes the same arguments as :meth:`delete`, with the exception of `recursive`. 
@@ -1535,7 +1746,7 @@ def delete_async(self, path, version=-1): ) return async_result - def _delete_recursive(self, path): + def _delete_recursive(self, path: str) -> Literal[True] | None: try: children = self.get_children(path) except NoNodeError: @@ -1553,8 +1764,15 @@ def _delete_recursive(self, path): self.delete(path) except NoNodeError: # pragma: nocover pass + return None - def reconfig(self, joining, leaving, new_members, from_config=-1): + def reconfig( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int = -1, + ) -> tuple[bytes, ZnodeStat]: """Reconfig a cluster. This call will succeed if the cluster was reconfigured accordingly. @@ -1625,9 +1843,15 @@ def reconfig(self, joining, leaving, new_members, from_config=-1): result = self.reconfig_async( joining, leaving, new_members, from_config ) - return result.get() + return cast("tuple[bytes, ZnodeStat]", result.get()) - def reconfig_async(self, joining, leaving, new_members, from_config): + def reconfig_async( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int, + ) -> IAsyncResult: """Asynchronously reconfig a cluster. Takes the same arguments as :meth:`reconfig`. @@ -1674,14 +1898,19 @@ class TransactionRequest(object): """ - def __init__(self, client): + def __init__(self, client: KazooClient): self.client = client - self.operations = [] + self.operations: list[Any] = [] self.committed = False def create( - self, path, value=b"", acl=None, ephemeral=False, sequence=False - ): + self, + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + ) -> None: """Add a create ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.create`, with the exception of `makepath`. 
@@ -1718,7 +1947,7 @@ def create( None, ) - def delete(self, path, version=-1): + def delete(self, path: str, version: int = -1) -> None: """Add a delete ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.delete`, with the exception of `recursive`. @@ -1730,7 +1959,7 @@ def delete(self, path, version=-1): raise TypeError("Invalid type for 'version' (int expected)") self._add(Delete(_prefix_root(self.client.chroot, path), version)) - def set_data(self, path, value, version=-1): + def set_data(self, path: str, value: bytes, version: int = -1) -> None: """Add a set ZNode value to the transaction. Takes the same arguments as :meth:`KazooClient.set`. @@ -1745,7 +1974,7 @@ def set_data(self, path, value, version=-1): SetData(_prefix_root(self.client.chroot, path), value, version) ) - def check(self, path, version): + def check(self, path: str, version: int) -> None: """Add a Check Version to the transaction. This command will fail and abort a transaction if the path @@ -1760,7 +1989,7 @@ def check(self, path, version): CheckVersion(_prefix_root(self.client.chroot, path), version) ) - def commit_async(self): + def commit_async(self) -> IAsyncResult: """Commit the transaction asynchronously. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -1772,28 +2001,38 @@ def commit_async(self): self.client._call(Transaction(self.operations), async_object) return async_object - def commit(self): + def commit(self) -> list[Any]: """Commit the transaction. :returns: A list of the results for each operation in the transaction. 
""" - return self.commit_async().get() + return cast("list[Any]", self.commit_async().get()) - def __enter__(self): + def __enter__(self) -> TransactionRequest: return self - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: """Commit and cleanup accumulated transaction data.""" if not exc_type: self.commit() + return None - def _check_tx_state(self): + def _check_tx_state(self) -> None: if self.committed: raise ValueError("Transaction already committed") - def _add(self, request, post_processor=None): + def _add( + self, + request: Any, + post_processor: Callable[[Any], Any] | None = None, + ) -> None: self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) self.operations.append(request) diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index b24c697c..ae9341df 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -1,6 +1,9 @@ """Kazoo Exceptions""" +from __future__ import annotations + from collections import defaultdict +from typing import Callable, Type class KazooException(Exception): @@ -51,17 +54,25 @@ class SASLException(KazooException): """ -def _invalid_error_code(): +def _invalid_error_code() -> Type[ZookeeperError]: raise RuntimeError("Invalid error code") -EXCEPTIONS = defaultdict(_invalid_error_code) +EXCEPTIONS: defaultdict[int, Type[ZookeeperError]] = defaultdict( + _invalid_error_code +) -def _zookeeper_exception(code): - def decorator(klass): +def _zookeeper_exception( + code: int, +) -> Callable[[Type[ZookeeperError]], Type[ZookeeperError]]: + def decorator(klass: Type[ZookeeperError]) -> Type[ZookeeperError]: EXCEPTIONS[code] = klass - klass.code = code + # Unfortunately there is currently no good of doing the assignment here + # in a way that type checkers would allow. 
It's a known problem (see + # https://discuss.python.org/t/how-to-type-hint-a-class-decorator/63010 + # ) + klass.code = code # type: ignore[attr-defined] return klass return decorator diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 8869cc57..ee8ed47d 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -1,10 +1,14 @@ """A eventlet based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import contextlib import logging +from typing import cast, Any, Generator, TYPE_CHECKING + import eventlet from eventlet.green import socket as green_socket from eventlet.green import time as green_time @@ -15,6 +19,18 @@ from kazoo.handlers import utils from kazoo.handlers.utils import selector_select +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + IHandler, + Lockable, + ReentrantLock, + Socket, + ) + from kazoo.protocol.states import Callback + + LOG = logging.getLogger(__name__) # sentinel objects @@ -22,17 +38,17 @@ @contextlib.contextmanager -def _yield_before_after(): +def _yield_before_after() -> Generator[None, None, None]: # Yield to any other co-routines... # # See: http://eventlet.net/doc/modules/greenthread.html # for how this zero sleep is really a cooperative yield to other potential # co-routines... 
- eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] try: yield finally: - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] class TimeoutError(Exception): @@ -42,12 +58,15 @@ class TimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: IHandler): super(AsyncResult, self).__init__( - handler, green_threading.Condition, TimeoutError + handler, + green_threading.Condition, # type: ignore[attr-defined] + TimeoutError, ) +# FIXME This should inherit from IHandler class SequentialEventletHandler(object): """Eventlet handler for sequentially executing callbacks. @@ -81,26 +100,35 @@ class SequentialEventletHandler(object): queue_impl = green_queue.LightQueue queue_empty = green_queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialEventletHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() - self._workers = [] + self.callback_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self.completion_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self._workers: list[ # type: ignore[name-defined] + tuple[ + eventlet.GreenThread, + green_queue.LightQueue, + ] + ] = [] self._started = False @staticmethod - def sleep_func(wait): - green_time.sleep(wait) + def sleep_func(wait: float) -> None: + green_time.sleep(wait) # type: ignore[attr-defined, no-untyped-call] @property - def running(self): + def running(self) -> bool: return self._started timeout_exception = TimeoutError - def _process_completion_queue(self): + def _process_completion_queue(self) -> None: while True: - cb = self.completion_queue.get() + cb = self.completion_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -114,9 +142,9 @@ def _process_completion_queue(self): finally: del cb # release before 
possible idle - def _process_callback_queue(self): + def _process_callback_queue(self) -> None: while True: - cb = self.callback_queue.get() + cb = self.callback_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -130,58 +158,83 @@ def _process_callback_queue(self): finally: del cb # release before possible idle - def start(self): + def start(self) -> None: if not self._started: # Spawn our worker threads, we have # - A callback worker for watch events to be called # - A completion worker for completion events to be called - w = eventlet.spawn(self._process_completion_queue) + w = eventlet.spawn( + self._process_completion_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.completion_queue)) - w = eventlet.spawn(self._process_callback_queue) + w = eventlet.spawn( + self._process_callback_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.callback_queue)) self._started = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: while self._workers: w, q = self._workers.pop() - q.put(_STOP) + q.put(_STOP) # type: ignore[no-untyped-call] w.wait() self._started = False atexit.unregister(self.stop) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: return utils.create_tcp_socket(green_socket) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(green_socket) - def event_object(self): - return green_threading.Event() + def event_object(self) -> Event: + return cast( + "Event", green_threading.Event() # type: ignore[attr-defined] + ) - def lock_object(self): - return green_threading.Lock() + def lock_object(self) -> Lockable: + return cast( + "Lockable", green_threading.Lock() # type: ignore[attr-defined] + ) - def rlock_object(self): - return green_threading.RLock() + def rlock_object(self) -> ReentrantLock: + return cast( + "ReentrantLock", + green_threading.RLock(), # type: ignore[attr-defined] + ) - def 
create_connection(self, *args, **kwargs): + # FIXME fix parameters + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) - def select(self, *args, **kwargs): + # FIXME fix parameters + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: with _yield_before_after(): + # Following appears to be a bug in mypy (see + # https://github.com/python/mypy/issues/6799) return selector_select( - *args, selectors_module=green_selectors, **kwargs + *args, + selectors_module=green_selectors, # type: ignore[misc] + **kwargs, ) - def async_result(self): - return AsyncResult(self) + def async_result(self) -> AsyncResult: + return AsyncResult(self) # type: ignore[arg-type] - def spawn(self, func, *args, **kwargs): - t = green_threading.Thread(target=func, args=args, kwargs=kwargs) + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> green_threading.Thread: # type: ignore[name-defined] + t = green_threading.Thread( # type: ignore[attr-defined] + target=func, args=args, kwargs=kwargs + ) t.daemon = True t.start() return t - def dispatch_callback(self, callback): - self.callback_queue.put(lambda: callback.func(*callback.args)) + def dispatch_callback(self, callback: Callback) -> None: + self.callback_queue.put( # type: ignore[no-untyped-call] + lambda: callback.func(*callback.args) + ) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index f36389aa..7a3c215c 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -1,9 +1,13 @@ """A gevent based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import logging +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + import gevent from gevent import socket import gevent.event @@ -14,18 +18,30 @@ from kazoo.handlers.utils import selector_select +from gevent import Greenlet from gevent.lock import Semaphore, 
RLock from kazoo.handlers import utils +if TYPE_CHECKING: + from kazoo.interfaces import FdLike, Lockable, Socket + from kazoo.protocol.states import Callback + _using_libevent = gevent.__version__.startswith("0.") log = logging.getLogger(__name__) + _STOP = object() AsyncResult = gevent.event.AsyncResult +# The following would be great typenames, but python3.8 complains about type +# objects not being indexable. +# GCallback = Callable[..., None] +# Worker = Greenlet[..., Any] +# CallbackQueue = gevent.queue.Queue[Callable[..., None]] + class SequentialGeventHandler(object): """Gevent handler for sequentially executing callbacks. @@ -53,24 +69,28 @@ class SequentialGeventHandler(object): queue_empty = gevent.queue.Empty sleep_func = staticmethod(gevent.sleep) - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialGeventHandler` instance""" - self.callback_queue = self.queue_impl() + self.callback_queue: gevent.queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._async = None self._state_change = Semaphore() - self._workers = [] + self._workers: list[Greenlet[..., Any]] = [] @property - def running(self): + def running(self) -> bool: return self._running class timeout_exception(gevent.Timeout): - def __init__(self, msg): + def __init__(self, msg: Any): gevent.Timeout.__init__(self, exception=msg) - def _create_greenlet_worker(self, queue): - def greenlet_worker(): + def _create_greenlet_worker( + self, queue: gevent.queue.Queue[Callable[..., None]] + ) -> Greenlet[..., Any]: + def greenlet_worker() -> None: while True: try: func = queue.get() @@ -88,7 +108,7 @@ def greenlet_worker(): return gevent.spawn(greenlet_worker) - def start(self): + def start(self) -> None: """Start the greenlet workers.""" with self._state_change: if self._running: @@ -98,12 +118,13 @@ def start(self): # Spawn our worker greenlets, we have # - A callback worker for watch events to be called + # FIXME Why the loop? 
for queue in (self.callback_queue,): w = self._create_greenlet_worker(queue) self._workers.append(w) atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the greenlet workers and empty all queues.""" with self._state_change: if not self._running: @@ -112,7 +133,7 @@ def stop(self): self._running = False for queue in (self.callback_queue,): - queue.put(_STOP) + queue.put(cast("Callable[..., None]", _STOP)) while self._workers: worker = self._workers.pop() @@ -123,33 +144,45 @@ def stop(self): atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: + # FIXME use the correct arguments, not *args, *kwargs return selector_select( - *args, selectors_module=gevent.selectors, **kwargs + # Likely a bug in mypy (see + # https://github.com/python/mypy/issues/6799) + *args, + selectors_module=gevent.selectors, + **kwargs, # type: ignore[misc] ) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: + # See above return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + # See above return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> gevent.event.Event: """Create an appropriate Event object""" return gevent.event.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create an appropriate Lock object""" - return gevent.thread.allocate_lock() + return cast( + "Lockable", + gevent.thread.allocate_lock(), # type: ignore[no-untyped-call] + ) - def rlock_object(self): + def rlock_object(self) -> RLock: """Create an appropriate RLock object""" return RLock() - def async_result(self): + def async_result(self) -> 
AsyncResult[Any]: """Create a :class:`AsyncResult` instance The :class:`AsyncResult` instance will have its completion @@ -160,11 +193,13 @@ def async_result(self): """ return AsyncResult() - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> gevent.Greenlet[..., Any]: """Spawn a function to run asynchronously""" return gevent.spawn(func, *args, **kwargs) - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index b9acd875..829a7010 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -10,6 +10,8 @@ :class:`~kazoo.handlers.gevent.SequentialGeventHandler` instead. """ + +from __future__ import annotations from __future__ import absolute_import import atexit @@ -19,9 +21,22 @@ import threading import time +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + from kazoo.handlers import utils from kazoo.handlers.utils import selector_select - +from kazoo.interfaces import IHandler + +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + Lockable, + ReentrantLock, + Socket, + SpawnedFunc, + ) + from kazoo.protocol.states import Callback # sentinel objects _STOP = object() @@ -29,7 +44,7 @@ log = logging.getLogger(__name__) -def _to_fileno(obj): +def _to_fileno(obj: FdLike) -> int: if isinstance(obj, int): fd = int(obj) elif hasattr(obj, "fileno"): @@ -55,13 +70,13 @@ class KazooTimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: Any) -> None: super(AsyncResult, self).__init__( handler, threading.Condition, KazooTimeoutError ) -class SequentialThreadingHandler(object): +class SequentialThreadingHandler(IHandler): 
"""Threading handler for sequentially executing callbacks. This handler executes callbacks in a sequential manner. A queue is @@ -96,20 +111,26 @@ class SequentialThreadingHandler(object): queue_impl = queue.Queue queue_empty = queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialThreadingHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() + self.callback_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() + self.completion_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._state_change = threading.Lock() - self._workers = [] + self._workers: list[threading.Thread] = [] @property - def running(self): + def running(self) -> bool: return self._running - def _create_thread_worker(self, work_queue): - def _thread_worker(): # pragma: nocover + def _create_thread_worker( + self, work_queue: queue.Queue[Callable[..., None]] + ) -> threading.Thread: + def _thread_worker() -> None: # pragma: nocover while True: try: func = work_queue.get() @@ -128,7 +149,7 @@ def _thread_worker(): # pragma: nocover t = self.spawn(_thread_worker) return t - def start(self): + def start(self) -> None: """Start the worker threads.""" with self._state_change: if self._running: @@ -143,7 +164,7 @@ def start(self): self._running = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the worker threads and empty all queues.""" with self._state_change: if not self._running: @@ -152,7 +173,7 @@ def stop(self): self._running = False for work_queue in (self.completion_queue, self.callback_queue): - work_queue.put(_STOP) + work_queue.put(cast("Callable[..., None]", _STOP)) self._workers.reverse() while self._workers: @@ -164,41 +185,47 @@ def stop(self): self.completion_queue = self.queue_impl() atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> 
tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: return selector_select(*args, **kwargs) - def socket(self): + def socket(self) -> Socket: return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> Event: """Create an appropriate Event object""" return threading.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create a lock object""" - return threading.Lock() + # Note: This is not ideal, but the ContextManager Protocol seems to + # think you should return an object of the same type. + return cast("Lockable", threading.Lock()) - def rlock_object(self): + def rlock_object(self) -> ReentrantLock: """Create an appropriate RLock object""" - return threading.RLock() + return cast("ReentrantLock", threading.RLock()) - def async_result(self): + def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance""" return AsyncResult(self) - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> threading.Thread: t = threading.Thread(target=func, args=args, kwargs=kwargs) t.daemon = True t.start() return t - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 206806f6..019cc2d2 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -1,5 +1,7 @@ """Kazoo handler helpers""" +from __future__ import annotations + from collections import defaultdict import errno import functools @@ -8,6 +10,14 @@ import ssl import 
socket import time +from types import ModuleType +from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING +from typing_extensions import ParamSpec + +from kazoo.interfaces import IAsyncResult, FdLike + +if TYPE_CHECKING: + from kazoo.interfaces import Socket HAS_FNCTL = True try: @@ -15,36 +25,51 @@ except ImportError: # pragma: nocover HAS_FNCTL = False + # sentinel objects +# Note: This needs to be a unique object that is not None, as None is used to +# indicate a successful result in AsyncResult. +# This should probably be an Enum, it would certainly be cleaner, but don't +# want to change the code too much. _NONE = object() +CallbackFunc = Callable[..., None] -class AsyncResult(object): + +class AsyncResult(IAsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler, condition_factory, timeout_factory): + def __init__( + self, + handler: Any, + condition_factory: Callable[[], Any], + timeout_factory: Callable[[], Any], + ) -> None: self._handler = handler - self._exception = _NONE + self._exception: object | Exception | None = _NONE self._condition = condition_factory() - self._callbacks = [] + self._callbacks: list[CallbackFunc] = [] self._timeout_factory = timeout_factory self.value = None - def ready(self): + def ready(self) -> bool: """Return true if and only if it holds a value or an exception""" return self._exception is not _NONE - def successful(self): + def successful(self) -> bool: """Return true if and only if it is ready and holds a value""" return self._exception is None @property - def exception(self): + def exception(self) -> Exception | None: if self._exception is not _NONE: - return self._exception + # The next line should have return-value, but hound ci + # is frankly nothing but a hound dog + return self._exception # type: ignore + return None - def set(self, value=None): + def set(self, value: Any = None) -> None: """Store the value. 
Wake up the waiters.""" with self._condition: self.value = value @@ -52,14 +77,14 @@ def set(self, value=None): self._do_callbacks() self._condition.notify_all() - def set_exception(self, exception): + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters.""" with self._condition: self._exception = exception self._do_callbacks() self._condition.notify_all() - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception. If there is no value raises TimeoutError. @@ -69,18 +94,18 @@ def get(self, block=True, timeout=None): if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] elif block: self._condition.wait(timeout) if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] # if we get to this point we timeout raise self._timeout_factory() - def get_nowait(self): + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. 
If nothing is available, raises TimeoutError @@ -88,14 +113,14 @@ def get_nowait(self): """ return self.get(block=False) - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Block until the instance is ready.""" with self._condition: if not self.ready(): self._condition.wait(timeout) return self._exception is not _NONE - def rawlink(self, callback): + def rawlink(self, callback: CallbackFunc) -> None: """Register a callback to call when a value or an exception is set""" with self._condition: @@ -106,7 +131,7 @@ def rawlink(self, callback): if self.ready(): self._do_callbacks() - def unlink(self, callback): + def unlink(self, callback: CallbackFunc) -> None: """Remove the callback set by :meth:`rawlink`""" with self._condition: if self.ready(): @@ -116,7 +141,7 @@ def unlink(self, callback): if callback in self._callbacks: self._callbacks.remove(callback) - def _do_callbacks(self): + def _do_callbacks(self) -> None: """Execute the callbacks that were registered by :meth:`rawlink`. If the handler is in running state this method only schedules the calls to be performed by the handler. If it's stopped, @@ -131,19 +156,21 @@ def _do_callbacks(self): functools.partial(callback, self)() -def _set_fd_cloexec(fd): +def _set_fd_cloexec(fd: Socket) -> None: flags = fcntl.fcntl(fd, fcntl.F_GETFD) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) -def _set_default_tcpsock_options(module, sock): +def _set_default_tcpsock_options(module: ModuleType, sock: Socket) -> Socket: sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1) if HAS_FNCTL: _set_fd_cloexec(sock) return sock -def create_socket_pair(module, port=0): +def create_socket_pair( + module: ModuleType, port: int = 0 +) -> tuple[Socket, Socket]: """Create socket pair. If socket.socketpair isn't available, we emulate it. 
@@ -182,32 +209,32 @@ def create_socket_pair(module, port=0): return client_sock, srv_sock -def create_tcp_socket(module): +def create_tcp_socket(module: ModuleType) -> Socket: """Create a TCP socket with the CLOEXEC flag set.""" type_ = module.SOCK_STREAM if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover # if available, set cloexec flag during socket creation type_ |= module.SOCK_CLOEXEC - sock = module.socket(module.AF_INET, type_) + sock: Socket = module.socket(module.AF_INET, type_) _set_default_tcpsock_options(module, sock) return sock def create_tcp_connection( - module, - address, - hostname=None, - timeout=None, - use_ssl=False, - ca=None, - certfile=None, - keyfile=None, - keyfile_password=None, - verify_certs=True, - check_hostname=False, - options=None, - ciphers=None, -): + module: ModuleType, + address: tuple[str, str | int], + hostname: str | None = None, + timeout: float | None = None, + use_ssl: bool = False, + ca: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + verify_certs: bool = True, + check_hostname: bool = False, + options: ssl.Options | None = None, + ciphers: str | None = None, +) -> Socket: end = None if timeout is None: # thanks to create_connection() developers for @@ -215,7 +242,7 @@ def create_tcp_connection( timeout = module.getdefaulttimeout() if timeout is not None: end = time.monotonic() + timeout - sock = None + sock: Socket | None = None while True: timeout_at = end if end is None else end - time.monotonic() @@ -279,7 +306,13 @@ def create_tcp_connection( sock = module.create_connection(address, timeout_at) break except Exception as ex: - errnum = ex.errno if isinstance(ex, OSError) else ex[0] + # Seriously WTF? if ex is an exception, how can it be a tuple? + # I guess gevent can do this, but really... 
+ errnum = ( + ex.errno + if isinstance(ex, OSError) + else ex[0] # type: ignore + ) if errnum == errno.EINTR: continue raise @@ -291,7 +324,16 @@ def create_tcp_connection( return sock -def capture_exceptions(async_result): +CapturedResult = TypeVar("CapturedResult") +GenericArgs = ParamSpec("GenericArgs") + + +def capture_exceptions( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -299,20 +341,30 @@ def capture_exceptions(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @functools.wraps(function) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: try: return function(*args, **kwargs) except Exception as exc: async_result.set_exception(exc) + return None return captured_function return capture -def wrap(async_result): +def wrap( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a non-None return value. 
@@ -321,9 +373,13 @@ def wrap(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @capture_exceptions(async_result) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: value = function(*args, **kwargs) if value is not None: async_result.set(value) @@ -334,7 +390,7 @@ def captured_function(*args, **kwargs): return capture -def fileobj_to_fd(fileobj): +def fileobj_to_fd(fileobj: FdLike) -> int: """Return a file descriptor from a file object. Parameters: @@ -349,18 +405,25 @@ def fileobj_to_fd(fileobj): if isinstance(fileobj, int): fd = fileobj else: + # FIXME given the protocol I don't think the try/catch/int are + # required. try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise TypeError("Invalid file object: " "{!r}".format(fileobj)) + # FIXME Questionable, just let select deal with it. if fd < 0: raise TypeError("Invalid file descriptor: {}".format(fd)) return fd def selector_select( - rlist, wlist, xlist, timeout=None, selectors_module=selectors -): + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + selectors_module: ModuleType = selectors, +) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: """Selector-based drop-in replacement for select to overcome select limitation on a maximum filehandle value. 
""" @@ -374,8 +437,8 @@ def selector_select( selectors_module.EVENT_READ: rlist, selectors_module.EVENT_WRITE: wlist, } - fd_events = defaultdict(int) - fd_fileobjs = defaultdict(list) + fd_events: defaultdict[int, int] = defaultdict(int) + fd_fileobjs: defaultdict[int, list[FdLike]] = defaultdict(list) for event, fileobjs in events_mapping.items(): for fileobj in fileobjs: @@ -391,7 +454,9 @@ def selector_select( # gevent can raise OSError raise ValueError("Invalid event mask or fd") from e - revents, wevents, xevents = [], [], [] + revents: list[FdLike] = [] + wevents: list[FdLike] = [] + xevents: list[FdLike] = [] try: ready = selector.select(timeout) finally: diff --git a/kazoo/hosts.py b/kazoo/hosts.py index 3ece9318..cda746a3 100644 --- a/kazoo/hosts.py +++ b/kazoo/hosts.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import urllib.parse -def collect_hosts(hosts): +def collect_hosts( + hosts: str | list[str], +) -> tuple[list[tuple[str, int]], str | None]: """ Collect a set of hosts and an optional chroot from a string or a list of strings. @@ -12,8 +16,8 @@ def collect_hosts(hosts): else: host_ports, chroot = hosts, None else: - host_ports, chroot = hosts.partition("/")[::2] - host_ports = host_ports.split(",") + host_ports_1, chroot = hosts.partition("/")[::2] + host_ports = host_ports_1.split(",") chroot = "/" + chroot if chroot else None result = [] diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 351f1fd8..71de8c21 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -8,10 +8,151 @@ """ +from __future__ import annotations + +import abc +import queue + +from types import TracebackType +from typing import ( + Any, + Callable, + Iterable, + Protocol, + Union, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from kazoo.protocol.states import Callback + # public API -class IHandler(object): +class HasFileNo(Protocol): + """Protocol for objects that support a fileno method.""" + + def fileno(self) -> int: + ... 
+ + +FdLike = Union[int, HasFileNo] + + +class Socket(HasFileNo, Protocol): + """This is for things that provide a socket.socket-like interface. + + This is required because: + 1. The socket in gevent doesn't inherit from socket.socket + 2. mypy gets confused if you have a method called socket and + subsequently attempt to use socket or socket.socket as a return type + """ + + def close(self) -> None: + ... + + def fileno(self) -> int: + ... + + def recv(self, bufsize: int, flags: int = 0) -> bytes: + ... + + def send(self, data: bytes | memoryview, flags: int = 0) -> int: + ... + + def sendall(self, data: bytes, flags: int = 0) -> None: + ... + + def setblocking(self, flags: bool) -> None: + ... + + def setsockopt(self, level: int, optname: int, value: int) -> None: + ... + + def getpeername(self) -> tuple[str, int]: + ... + + +class Lockable(Protocol): + """This is what threading.Lock implements. + + In python 3.9+ it's available natively. Though given it has some + very odd typing, I wouldn't put money on it. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> int | None: + """The gevent release returns an int...""" + ... + + def locked(self) -> bool: + ... + + +class ReentrantLock(Protocol): + """This is what threading.RLock implements. + + In python 3.14+, it's the same as Lock, which adds to the fun. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> None: + ... 
+ + +class Event(Protocol): + """Protocol for threading.Event""" + + def is_set(self) -> bool: + ... + + def set(self) -> None: + ... + + def clear(self) -> None: + ... + + def wait(self, timeout: float | None = None) -> bool: + ... + + +class Threadlike(Protocol): + """Protocol for something like a thread.""" + + def join(self, timeout: float | None = None) -> None: + ... + + +SpawnedFunc = Callable[..., None] + + +class IHandler(abc.ABC): """A Callback Handler for Zookeeper completion and watch callbacks. This object must implement several methods responsible for @@ -44,43 +185,69 @@ class IHandler(object): """ - def start(self): + timeout_exception: type[Exception] = None # type: ignore[assignment] + sleep_func: staticmethod[[float], None] = None # type: ignore[assignment] + queue_impl: type[queue.Queue[Any]] = None # type: ignore[assignment] + + @abc.abstractmethod + def start(self) -> None: """Start the handler, used for setting up the handler.""" - def stop(self): + @abc.abstractmethod + def stop(self) -> None: """Stop the handler. Should block until the handler is safely stopped.""" - def select(self): + @abc.abstractmethod + def select( + self, + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: """A select method that implements Python's select.select API""" - def socket(self): - """A socket method that implements Python's socket.socket + @abc.abstractmethod + def socket(self) -> Socket: + """A socket method that implements Python's socket.socket API""" + + # FIXME This should have a proper set of parameters. 
+ @abc.abstractmethod + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + """A socket method that implements Python's socket.create_connection API""" - def create_connection(self): - """A socket method that implements Python's - socket.create_connection API""" + @abc.abstractmethod + def create_socket_pair(self) -> tuple[Socket, Socket]: + """A socket method that implements Python's socket.socketpair API""" - def event_object(self): + @abc.abstractmethod + def event_object(self) -> Event: """Return an appropriate object that implements Python's threading.Event API""" - def lock_object(self): + @abc.abstractmethod + def lock_object(self) -> Lockable: """Return an appropriate object that implements Python's threading.Lock API""" - def rlock_object(self): + @abc.abstractmethod + def rlock_object(self) -> ReentrantLock: """Return an appropriate object that implements Python's threading.RLock API""" - def async_result(self): + @abc.abstractmethod + def async_result(self) -> IAsyncResult: """Return an instance that conforms to the :class:`~IAsyncResult` interface appropriate for this handler""" - def spawn(self, func, *args, **kwargs): + @abc.abstractmethod + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> Threadlike: """Spawn a function to run asynchronously :param args: args to call the function with. @@ -91,7 +258,8 @@ def spawn(self, func, *args, **kwargs): """ - def dispatch_callback(self, callback): + @abc.abstractmethod + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object :param callback: A :class:`~kazoo.protocol.states.Callback` @@ -100,7 +268,7 @@ def dispatch_callback(self, callback): """ -class IAsyncResult(object): +class IAsyncResult(abc.ABC): """An Async Result object that can be queried for a value that has been set asynchronously. 
@@ -123,15 +291,18 @@ class IAsyncResult(object): """ - def ready(self): + @abc.abstractmethod + def ready(self) -> bool: """Return `True` if and only if it holds a value or an exception""" - def successful(self): + @abc.abstractmethod + def successful(self) -> bool: """Return `True` if and only if it is ready and holds a value""" - def set(self, value=None): + @abc.abstractmethod + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters. :param value: Value to store as the result. @@ -140,7 +311,8 @@ def set(self, value=None): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def set_exception(self, exception): + @abc.abstractmethod + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters. :param exception: Exception to raise when fetching the value. @@ -149,7 +321,8 @@ def set_exception(self, exception): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def get(self, block=True, timeout=None): + @abc.abstractmethod + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception :param block: Whether this method should block or return @@ -164,13 +337,15 @@ def get(self, block=True, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def get_nowait(self): + @abc.abstractmethod + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raise the Timeout exception class on the associated :class:`IHandler` interface.""" - def wait(self, timeout=None): + @abc.abstractmethod + def wait(self, timeout: float | None = None) -> Any: """Block until the instance is ready. 
:param timeout: How long to wait for a value when `block` is @@ -182,7 +357,8 @@ def wait(self, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def rawlink(self, callback): + @abc.abstractmethod + def rawlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """Register a callback to call when a value or an exception is set @@ -194,10 +370,17 @@ def rawlink(self, callback): """ - def unlink(self, callback): + @abc.abstractmethod + def unlink(self, callback: Callable[[IAsyncResult], None]) -> None: """Remove the callback set by :meth:`rawlink` :param callback: A callback function to remove. :type callback: func """ + + @property + @abc.abstractmethod + def exception(self) -> Exception | None: + """The exception set by :meth:`set_exception` or `None` if no + exception has been set""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 3df7b162..e4ed1a17 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -1,4 +1,7 @@ """Zookeeper Protocol Connection Handler""" + +from __future__ import annotations + from binascii import hexlify from contextlib import contextmanager import copy @@ -8,6 +11,17 @@ import socket import ssl import time +from typing import ( + Callable, + ContextManager, + Iterator, + Literal, + TypeVar, + TYPE_CHECKING, + cast, + overload, +) +from typing_extensions import Buffer from kazoo.exceptions import ( AuthFailedError, @@ -41,9 +55,14 @@ ) from kazoo.retry import ( ForceRetryError, + KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import Socket, Threadlike + try: import puresasl import puresasl.client @@ -76,7 +95,7 @@ # removed from Python3+ -def buffer(obj, offset=0): +def buffer(obj: Buffer, offset: int = 0) -> memoryview: return memoryview(obj)[offset:] @@ -94,22 +113,32 @@ class RWPinger(object): """ - def __init__(self, hosts, connection_func, 
socket_handling): + def __init__( + self, + hosts: list[tuple[str, int]], + connection_func: Callable[[tuple[str, int]], Socket], + socket_handling: Callable[[], ContextManager[None]], + ): self.hosts = hosts self.connection = connection_func - self.last_attempt = None + self.last_attempt: float | None = None self.socket_handling = socket_handling - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, int] | Literal[False] | None]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay): + def _next_server( + self, delay: float + ) -> tuple[str, int] | Literal[False] | None: jitter = random.randint(0, 100) / 100.0 - while time.monotonic() < self.last_attempt + delay + jitter: + while ( + time.monotonic() + < self.last_attempt + delay + jitter # type: ignore[operator] + ): # Skip rw ping checks if its too soon return False for host, port in self.hosts: @@ -128,20 +157,36 @@ def _next_server(self, delay): except ConnectionDropped: return False + # NOTE: This does actually look like it's unreachable but I don't + # want to alter the code any more than necessary for the first + # pass. 
+ # The loop is basically a sleep with jitter that can be # Add some jitter between host pings - while time.monotonic() < self.last_attempt + jitter: + while ( # type: ignore[unreachable] + time.monotonic() < self.last_attempt + jitter + ): return False - delay *= 2 + delay *= 2 # And while not unreachable, this is pointless + return None class RWServerAvailable(Exception): """Thrown if a RW Server becomes available""" +ReturnValue = TypeVar("ReturnValue") + + class ConnectionHandler(object): """Zookeeper connection handler""" - def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): + def __init__( + self, + client: KazooClient, + retry_sleeper: KazooRetry, + logger: logging.Logger | None = None, + sasl_options: dict[str, str] | None = None, + ): self.client = client self.handler = client.handler self.retry_sleeper = retry_sleeper @@ -154,15 +199,17 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.connection_stopped.set() self.ping_outstanding = client.handler.event_object() - self._read_sock = None - self._write_sock = None + self._read_sock: Socket | None = None + self._write_sock: Socket | None = None - self._socket = None - self._xid = None - self._rw_server = None - self._ro_mode = False + self._socket: Socket | None = None + self._xid: int | None = None + self._rw_server: tuple[str, int] | None = None + self._ro_mode: Iterator[ + Literal[False] | tuple[str, int] | None + ] | Literal[False] | None = False - self._connection_routine = None + self._connection_routine: Threadlike | None = None self.sasl_options = sasl_options self.sasl_cli = None @@ -170,14 +217,14 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): # This is instance specific to avoid odd thread bug issues in Python # during shutdown global cleanup @contextmanager - def _socket_error_handling(self): + def _socket_error_handling(self) -> Iterator[None]: try: yield except (socket.error, select.error) as e: err = 
getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) - def start(self): + def start(self) -> None: """Start the connection up""" if self.connection_closed.is_set(): rw_sockets = self.handler.create_socket_pair() @@ -189,7 +236,7 @@ def start(self): ) self._connection_routine = self.handler.spawn(self.zk_loop) - def stop(self, timeout=None): + def stop(self, timeout: float | None = None) -> bool: """Ensure the writer has stopped, wait to see if it does.""" self.connection_stopped.wait(timeout) if self._connection_routine: @@ -197,7 +244,7 @@ def stop(self, timeout=None): self._connection_routine = None return self.connection_stopped.is_set() - def close(self): + def close(self) -> None: """Release resources held by the connection The connection can be restarted afterwards. @@ -212,7 +259,7 @@ def close(self): if rs is not None: rs.close() - def _server_pinger(self): + def _server_pinger(self) -> RWPinger: """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" return RWPinger( @@ -221,16 +268,21 @@ def _server_pinger(self): self._socket_error_handling, ) - def _read_header(self, timeout): + def _read_header( + self, timeout: float | None + ) -> tuple[ReplyHeader, bytes, int]: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) header, offset = ReplyHeader.deserialize(b, 0) return header, b, offset - def _read(self, length, timeout): + def _read(self, length: int, timeout: float | None) -> bytes: msgparts = [] remaining = length + # We know that self._socket is not None here because we only call + # this method when we have set up the connection and the read socket. + # But mypy doesn't understand that. with self._socket_error_handling(): while remaining > 0: # Because of SSL framing, a select may not return when using @@ -240,11 +292,13 @@ def _read(self, length, timeout): # data from the underlying socket. 
if ( hasattr(self._socket, "pending") - and self._socket.pending() > 0 + and self._socket.pending() > 0 # type: ignore[union-attr] ): pass else: - s = self.handler.select([self._socket], [], [], timeout)[0] + s = self.handler.select( + [cast("Socket", self._socket)], [], [], timeout + )[0] if not s: # pragma: nocover # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any @@ -252,7 +306,9 @@ def _read(self, length, timeout): "socket time-out during read" ) try: - chunk = self._socket.recv(remaining) + chunk = self._socket.recv( # type: ignore[union-attr] + remaining + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -267,7 +323,24 @@ def _read(self, length, timeout): remaining -= len(chunk) return b"".join(msgparts) - def _invoke(self, timeout, request, xid=None): + @overload + def _invoke( + self, timeout: float | None, request: Connect + ) -> tuple[Connect, int | None]: + ... + + @overload + def _invoke( + self, timeout: float | None, request: Auth, xid: int + ) -> int | None: + ... + + def _invoke( + self, + timeout: float | None, + request: Auth | Connect, + xid: int | None = None, + ) -> tuple[Connect, int | None] | int | None: """A special writer used during connection establishment only""" self._submit(request, timeout, xid) @@ -296,7 +369,9 @@ def _invoke(self, timeout, request, xid=None): if hasattr(request, "deserialize"): try: - obj, _ = request.deserialize(msg, 0) + # This is a bit of an annoying ignore as I've just done a + # hasattr... 
+ obj, _ = request.deserialize(msg, 0) # type:ignore[union-attr] except Exception: self.logger.exception( "Exception raised during deserialization " @@ -311,7 +386,12 @@ def _invoke(self, timeout, request, xid=None): return zxid - def _submit(self, request, timeout, xid=None): + def _submit( + self, + request: Auth | Connect | Ping | SASL, + timeout: float | None, + xid: int | None = None, + ) -> None: """Submit a request object with a timeout value and optional xid""" b = bytearray() @@ -328,13 +408,17 @@ def _submit(self, request, timeout, xid=None): ) self._write(int_struct.pack(len(b)) + b, timeout) - def _write(self, msg, timeout): + def _write(self, msg: bytes, timeout: float | None) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) + # Note: The casts/type: ignore are because mypy can't work out + # self._socket is not None, and I don't want to change any code. with self._socket_error_handling(): while sent < msg_length: - s = self.handler.select([], [self._socket], [], timeout)[1] + s = self.handler.select( + [], [cast("Socket", self._socket)], [], timeout + )[1] if not s: # pragma: nocover # If the write list is empty, we got a timeout. 
We don't # have to check rlist and xlist as we don't set any @@ -343,7 +427,9 @@ def _write(self, msg, timeout): ) msg_slice = buffer(msg, sent) try: - bytes_sent = self._socket.send(msg_slice) + bytes_sent = self._socket.send( # type:ignore[union-attr] + msg_slice + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -356,14 +442,14 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent - def _read_watch_event(self, buffer, offset): + def _read_watch_event(self, buffer: bytes, offset: int) -> None: client = self.client watch, offset = Watch.deserialize(buffer, offset) path = watch.path self.logger.debug("Received EVENT: %s", watch) - watchers = [] + watchers: list[WatchFunc] = [] if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) @@ -385,10 +471,15 @@ def _read_watch_event(self, buffer, offset): return # Dump the watchers to the watch thread - for watch in watchers: - client.handler.dispatch_callback(Callback("watch", watch, (ev,))) - - def _read_response(self, header, buffer, offset): + for watch1 in watchers: + client.handler.dispatch_callback(Callback("watch", watch1, (ev,))) + + def _read_response( + self, + header: ReplyHeader, + buffer: bytes, + offset: int, + ) -> object | None: client = self.client request, async_object, xid = client._pending.popleft() if header.zxid and header.zxid > 0: @@ -404,7 +495,11 @@ def _read_response(self, header, buffer, offset): # Determine if its an exists request and a no node error exists_error = ( - header.err == NoNodeError.code and request.type == Exists.type + # NoNodeError does actually have a code. It's added by a decorator, + # which could possibly be better done via inheritance but this is + # less invasive to the existing code. 
+ header.err == NoNodeError.code # type: ignore[attr-defined] + and request.type == Exists.type ) # Set the exception if its not an exists error @@ -430,7 +525,7 @@ def _read_response(self, header, buffer, offset): request, ) async_object.set_exception(exc) - return + return None self.logger.debug( "Received response(xid=%s): %r", xid, response ) @@ -452,8 +547,9 @@ def _read_response(self, header, buffer, offset): if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE + return None - def _read_socket(self, read_timeout): + def _read_socket(self, read_timeout: float) -> object | None: """Called when there's something to read on the socket""" client = self.client @@ -476,8 +572,13 @@ def _read_socket(self, read_timeout): self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) + return None - def _send_request(self, read_timeout, connect_timeout): + def _send_request( + self, + read_timeout: float, + connect_timeout: float, + ) -> None: """Called when we have something to send out on the socket""" client = self.client try: @@ -489,7 +590,11 @@ def _send_request(self, read_timeout, connect_timeout): try: # Clear possible inconsistence (no request in the queue # but have data in the read socket), which causes cpu to spin. - self._read_sock.recv(1) + # + # We know _read_sock is not None because we only call this + # method when we have set up the connection and the read + # socket, but mypy doesn't understand that. + self._read_sock.recv(1) # type: ignore[union-attr] except OSError: pass return @@ -505,15 +610,19 @@ def _send_request(self, read_timeout, connect_timeout): if request.type == Auth.type: xid = AUTH_XID else: - self._xid = (self._xid % 2147483647) + 1 + # We must have initialised the xid counter by now + # Might want to consider initialising it to 0 instead of none? 
+ self._xid = (self._xid % 2147483647) + 1 # type: ignore[operator] xid = self._xid self._submit(request, connect_timeout, xid) client._queue.popleft() - self._read_sock.recv(1) + # _read_sock should never be None here as we only call this method + # when we have set up the connection and the read socket. + self._read_sock.recv(1) # type: ignore[union-attr] client._pending.append((request, async_object, xid)) - def _send_ping(self, connect_timeout): + def _send_ping(self, connect_timeout: float) -> None: self.ping_outstanding.set() self._submit(PingInstance, connect_timeout, PING_XID) @@ -524,7 +633,7 @@ def _send_ping(self, connect_timeout): self._rw_server = result raise RWServerAvailable() - def zk_loop(self): + def zk_loop(self) -> None: """Main Zookeeper handling loop""" self.logger.log(BLATHER, "ZK loop started") @@ -546,16 +655,28 @@ def zk_loop(self): self.client._session_callback(KeeperState.CLOSED) self.logger.log(BLATHER, "Connection stopped") - def _expand_client_hosts(self): + def _expand_client_hosts(self) -> list[tuple[str, str, int]]: # Expand the entire list in advance so we can randomize it if needed - host_ports = [] + host_ports: list[tuple[str, str, int]] = [] for host, port in self.client.hosts: try: host = host.strip() for rhost in socket.getaddrinfo( host, port, 0, 0, socket.IPPROTO_TCP ): - host_ports.append((host, rhost[4][0], rhost[4][1])) + # FIXME These casts seem to be unnecessary on later + # versions of mypy/python + host_ports.append( + ( + host, + cast( # type: ignore[redundant-cast] + "str", rhost[4][0] + ), + cast( # type: ignore[redundant-cast] + "int", rhost[4][1] + ), + ) + ) except socket.gaierror as e: # Skip hosts that don't resolve self.logger.warning("Cannot resolve %s: %s", host, e) @@ -564,7 +685,7 @@ def _expand_client_hosts(self): random.shuffle(host_ports) return host_ports - def _connect_loop(self, retry): + def _connect_loop(self, retry: KazooRetry) -> object: # Iterate through the hosts a full cycle before 
starting over status = None host_ports = self._expand_client_hosts() @@ -586,7 +707,13 @@ def _connect_loop(self, retry): else: raise ForceRetryError("Reconnecting") - def _connect_attempt(self, host, hostip, port, retry): + def _connect_attempt( + self, + host: str, + hostip: str, + port: int, + retry: KazooRetry, + ) -> object: client = self.client KazooTimeoutError = self.handler.timeout_exception @@ -606,6 +733,9 @@ def _connect_attempt(self, host, hostip, port, retry): try: self._xid = 0 read_timeout, connect_timeout = self._connect(host, hostip, port) + # I think the above implies self._socket can't be none, and + # self._read_sock is set up in start but mypy can't tell that. + # Hence the casting. read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -619,7 +749,14 @@ def _connect_attempt(self, host, hostip, port, retry): # Ensure our timeout is positive timeout = max([deadline - time.monotonic(), jitter_time]) s = self.handler.select( - [self._socket, self._read_sock], [], [], timeout + [ + # FIXME we should know these aren't None + cast("Socket", self._socket), + cast("Socket", self._read_sock), + ], + [], + [], + timeout, )[0] if not s: @@ -629,14 +766,14 @@ def _connect_attempt(self, host, hostip, port, retry): "outstanding heartbeat ping not received" ) else: - if self._socket in s: + if cast("Socket", self._socket) in s: response = self._read_socket(read_timeout) if response == CLOSE_RESPONSE: break # Check if any requests need sending before proceeding # to process more responses. Otherwise the responses # may choke out the requests. See PR#633. - if self._read_sock in s: + if cast("Socket", self._read_sock) in s: self._send_request(read_timeout, connect_timeout) # Requests act as implicit pings. 
last_send = time.monotonic() @@ -674,9 +811,18 @@ def _connect_attempt(self, host, hostip, port, retry): raise finally: if self._socket is not None: - self._socket.close() - - def _connect(self, host, hostip, port): + # I think this is a bug in mypy, as the socket does get set up + # in self._connect, but it doesn't seem to be able to track + # that. + self._socket.close() # type: ignore[unreachable] + return None + + def _connect( + self, + host: str, + hostip: str, + port: int, + ) -> tuple[float, float]: client = self.client self.logger.info( "Connecting to %s(%s):%s, use_ssl: %r", @@ -707,7 +853,7 @@ def _connect(self, host, hostip, port): check_hostname=self.client.check_hostname, ) - self._socket.setblocking(0) + self._socket.setblocking(0) # type: ignore[arg-type] connect = Connect( 0, @@ -771,23 +917,36 @@ def _connect(self, host, hostip, port): return read_timeout, connect_timeout - def _authenticate_with_sasl(self, host, timeout): + def _authenticate_with_sasl(self, host: str, timeout: float) -> None: """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: raise SASLException("Missing SASL support") - if "service" not in self.sasl_options: - self.sasl_options["service"] = "zookeeper" + # Although this can only be called if sasl_options is not None, we + # really should just have make self.sasl_options into an empty dict + # in the constructor. However, I want to avoid code changes in as + # much as possible. + if "service" not in self.sasl_options: # type: ignore[operator] + self.sasl_options["service"] = "zookeeper" # type: ignore[index] # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. 
- if self.sasl_options["mechanism"] == "DIGEST-MD5": + if ( + self.sasl_options["mechanism"] # type: ignore[index] + == "DIGEST-MD5" + ): host = "zk-sasl-md5" - sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, **self.sasl_options + # I don't think the client.sasl_cli attribute is actually used + # anywhere else, so not sure why we need to set it on the client, + # but again, I want to avoid code changes as much as possible. + sasl_cli = ( + self.client.sasl_cli # type: ignore[attr-defined] + ) = puresasl.client.SASLClient( # type: ignore[no-untyped-call] + host=host, + **self.sasl_options, # type: ignore[arg-type] ) # Initialize the process with an empty challenge token diff --git a/kazoo/protocol/paths.py b/kazoo/protocol/paths.py index b8bf6650..7c47ce8a 100644 --- a/kazoo/protocol/paths.py +++ b/kazoo/protocol/paths.py @@ -1,4 +1,4 @@ -def normpath(path, trailing=False): +def normpath(path: str, trailing: bool = False) -> str: """Normalize path, eliminating double slashes, etc.""" comps = path.split("/") new_comps = [] @@ -16,7 +16,7 @@ def normpath(path, trailing=False): return new_path -def join(a, *p): +def join(a: str, *p: str) -> str: """Join two or more pathname components, inserting '/' as needed. 
If any component is an absolute path, all previous path components @@ -34,23 +34,23 @@ def join(a, *p): return path -def isabs(s): +def isabs(s: str) -> bool: """Test whether a path is absolute""" return s.startswith("/") -def basename(p): +def basename(p: str) -> str: """Returns the final component of a pathname""" i = p.rfind("/") + 1 return p[i:] -def _prefix_root(root, path, trailing=False): +def _prefix_root(root: str, path: str, trailing: bool = False) -> str: """Prepend a root to a path.""" return normpath( join(_norm_root(root), path.lstrip("/")), trailing=trailing ) -def _norm_root(root): +def _norm_root(root: str) -> str: return normpath(join("/", root)) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c..29a55f35 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -1,12 +1,23 @@ -"""Zookeeper Serializers, Deserializers, and NamedTuple objects""" -from collections import namedtuple +"""Zookeeper Serializers, Deserializers, and namedtuple objects + +Note: On python3.8, you can't do classvars with NamedTuple. + +FIXME As soon as we get off python3.8 we should change the namedtuple objects +to NamedTuple, as it should get better typechecking. 
+""" +from __future__ import annotations + import struct +from collections import namedtuple +from typing import ClassVar, Sequence, Union, TYPE_CHECKING -from kazoo.exceptions import EXCEPTIONS +from kazoo.exceptions import EXCEPTIONS, ZookeeperError from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc # Struct objects with formats compiled bool_struct = struct.Struct("B") @@ -21,20 +32,24 @@ stat_struct = struct.Struct("!qqqqiiiqiiq") -def read_string(buffer, offset): +def read_string(buffer: bytes, offset: int) -> tuple[str, int]: """Reads an int specified buffer into a string and returns the string and the new offset in the buffer""" length = int_struct.unpack_from(buffer, offset)[0] offset += int_struct.size if length < 0: - return None, offset + # A note: write_str sends a length of -1 to indicate a value of None + # was passed. Not entirely sure where this happens because none of the + # callers of read_string seem to expect a None value. + # Should be ignoring return-value but hound cli... 
+ return None, offset # type: ignore else: index = offset offset += length return buffer[index : index + length].decode("utf-8"), offset -def read_acl(bytes, offset): +def read_acl(bytes: bytes, offset: int) -> tuple[ACL, int]: perms = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size scheme, offset = read_string(bytes, offset) @@ -42,7 +57,7 @@ def read_acl(bytes, offset): return ACL(perms, Id(scheme, id)), offset -def write_string(bytes): +def write_string(bytes: str | None) -> bytes: if not bytes: return int_struct.pack(-1) else: @@ -50,14 +65,14 @@ def write_string(bytes): return int_struct.pack(len(utf8_str)) + utf8_str -def write_buffer(bytes): +def write_buffer(bytes: bytes | None) -> bytes: if bytes is None: return int_struct.pack(-1) else: return int_struct.pack(len(bytes)) + bytes -def read_buffer(bytes, offset): +def read_buffer(bytes: bytes, offset: int) -> tuple[bytes | None, int]: length = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if length < 0: @@ -69,10 +84,10 @@ def read_buffer(bytes, offset): class Close(namedtuple("Close", "")): - type = -11 + type: ClassVar[int] = -11 @classmethod - def serialize(cls): + def serialize(cls) -> bytes: return b"" @@ -80,10 +95,10 @@ def serialize(cls): class Ping(namedtuple("Ping", "")): - type = 11 + type: ClassVar[int] = 11 @classmethod - def serialize(cls): + def serialize(cls) -> bytes: return b"" @@ -97,9 +112,16 @@ class Connect( " time_out session_id passwd read_only", ) ): - type = None + protocol_version: int + last_zxid_seen: int + time_out: int + session_id: int + passwd: bytes + read_only: bool - def serialize(self): + type: int | None = None # Note: Not a classvar + + def serialize(self) -> bytearray: b = bytearray() b.extend( int_long_int_long_struct.pack( @@ -114,7 +136,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Connect, int]: proto_version, timeout,
session_id = int_int_long_struct.unpack_from( bytes, offset ) @@ -133,9 +155,14 @@ def deserialize(cls, bytes, offset): class Create(namedtuple("Create", "path data acl flags")): - type = 1 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int + + type: ClassVar[int] = 1 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -150,59 +177,74 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> str: return read_string(bytes, offset)[0] class Delete(namedtuple("Delete", "path version")): - type = 2 + path: str + version: int - def serialize(self): + type: ClassVar[int] = 2 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> bool: return True class Exists(namedtuple("Exists", "path watcher")): - type = 3 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 3 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat | None: + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return stat if stat.czxid != -1 else None class GetData(namedtuple("GetData", "path watcher")): - type = 4 + path: str + watcher: WatchFunc | None - def serialize(self): + type: ClassVar[int] = 4 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) 
-> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class SetData(namedtuple("SetData", "path data version")): - type = 5 + path: str + data: bytes | None + version: int + + type: ClassVar[int] = 5 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -210,18 +252,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetACL(namedtuple("GetACL", "path")): - type = 6 + path: str - def serialize(self): + type: ClassVar[int] = 6 + + def serialize(self) -> bytearray: return bytearray(write_string(self.path)) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[ACL], ZnodeStat] | list[ACL]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -231,14 +277,18 @@ def deserialize(cls, bytes, offset): for c in range(count): acl, offset = read_acl(bytes, offset) acls.append(acl) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return acls, stat class SetACL(namedtuple("SetACL", "path acls version")): - type = 7 + path: str + acls: Sequence[ACL] + version: int + + type: ClassVar[int] = 7 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(len(self.acls))) @@ -252,21 +302,24 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, 
offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetChildren(namedtuple("GetChildren", "path watcher")): - type = 8 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 8 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -280,54 +333,71 @@ def deserialize(cls, bytes, offset): class Sync(namedtuple("Sync", "path")): - type = 9 + path: str - def serialize(self): + type: ClassVar[int] = 9 + + def serialize(self) -> bytes: return write_string(self.path) @classmethod - def deserialize(cls, buffer, offset): + def deserialize(cls, buffer: bytes, offset: int) -> str: return read_string(buffer, offset)[0] class GetChildren2(namedtuple("GetChildren2", "path watcher")): - type = 12 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 12 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[str], ZnodeStat] | list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover return [] - children = [] + children: list[str] = [] for c in range(count): child, offset = read_string(bytes, offset) children.append(child) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return children, stat class CheckVersion(namedtuple("CheckVersion", "path version")): - type = 13 + path: str + version: int + + 
type: ClassVar[int] = 13 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b +# FIXME Transaction class should move after Create2 +Transaction_Types = Union[Create, "Create2", Delete, SetData, CheckVersion] +Transaction_Response = Union[str, bool, ZnodeStat, ZookeeperError, None] + + class Transaction(namedtuple("Transaction", "operations")): - type = 14 + operations: list[Transaction_Types] + + type: ClassVar[int] = 14 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() for op in self.operations: b.extend( @@ -336,19 +406,19 @@ def serialize(self): return b + multiheader_struct.pack(-1, True, -1) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> list[Transaction_Response]: header = MultiHeader(None, False, None) - results = [] - response = None + results: list[Transaction_Response] = [] + response: Transaction_Response = None while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) elif header.type == Delete.type: response = True elif header.type == SetData.type: - response = ZnodeStat._make( - stat_struct.unpack_from(bytes, offset) - ) + response = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) offset += stat_struct.size elif header.type == CheckVersion.type: response = True @@ -362,8 +432,10 @@ def deserialize(cls, bytes, offset): return results @staticmethod - def unchroot(client, response): - resp = [] + def unchroot( + client: KazooClient, response: list[Transaction_Response] + ) -> list[Transaction_Response]: + resp: list[Transaction_Response] = [] for result in response: if isinstance(result, str): resp.append(client.unchroot(result)) @@ -373,9 +445,14 @@ def unchroot(client, response): class Create2(namedtuple("Create2", "path data acl flags")): - type = 15 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int 
- def serialize(self): + type: ClassVar[int] = 15 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -390,18 +467,23 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[str, ZnodeStat]: path, offset = read_string(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return path, stat class Reconfig( namedtuple("Reconfig", "joining leaving new_members config_id") ): - type = 16 + joining: str | None + leaving: str | None + new_members: str | None + config_id: int + + type: ClassVar[int] = 16 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.joining)) b.extend(write_string(self.leaving)) @@ -410,16 +492,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class Auth(namedtuple("Auth", "auth_type scheme auth")): - type = 100 + auth_type: int + scheme: str + auth: str - def serialize(self): + type: ClassVar[int] = 100 + + def serialize(self) -> bytes: return ( int_struct.pack(self.auth_type) + write_string(self.scheme) @@ -428,22 +516,30 @@ def serialize(self): class SASL(namedtuple("SASL", "challenge")): - type = 102 + challenge: bytes | None + + type: ClassVar[int] = 102 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_buffer(self.challenge)) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, int]: challenge, offset = read_buffer(bytes, 
offset) return challenge, offset class Watch(namedtuple("Watch", "type state path")): + type: int + state: int + path: str + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Watch, int]: """Given bytes and the current bytes offset, return the type, state, path, and new offset""" type, state = int_int_struct.unpack_from(bytes, offset) @@ -453,19 +549,27 @@ def deserialize(cls, bytes, offset): class ReplyHeader(namedtuple("ReplyHeader", "xid, zxid, err")): + xid: int + zxid: int + err: int + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[ReplyHeader, int]: """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size return ( - cls._make(reply_header_struct.unpack_from(bytes, offset)), + cls(*reply_header_struct.unpack_from(bytes, offset)), new_offset, ) -class MultiHeader(namedtuple("MultiHeader", "type done err")): - def serialize(self): +class MultiHeader(namedtuple("MultiHeader", "type, done, err")): + type: int | None + done: bool + err: int | None + + def serialize(self) -> bytearray: b = bytearray() b.extend(int_struct.pack(self.type)) b.extend([1 if self.done else 0]) @@ -473,7 +577,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[MultiHeader, int]: t, done, err = multiheader_struct.unpack_from(bytes, offset) offset += multiheader_struct.size return cls(t, done == 1, err), offset diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e..ad835266 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -1,8 +1,13 @@ """Kazoo State and Event objects""" -from collections import namedtuple +from __future__ import annotations -class KazooState(object): +from enum import Enum +from typing import Any, Callable, 
NamedTuple + + +# This is a (str, Enum) for backwards compatibility. +class KazooState(str, Enum): """High level connection state values States inspired by Netflix Curator. @@ -33,7 +38,8 @@ class KazooState(object): LOST = "LOST" -class KeeperState(object): +# This is a (str, Enum) for backwards compatibility. +class KeeperState(str, Enum): """Zookeeper State Represents the Zookeeper state. Watch functions will receive a @@ -70,7 +76,8 @@ class KeeperState(object): EXPIRED_SESSION = "EXPIRED_SESSION" -class EventType(object): +# This is a (str, Enum) for backwards compatibility. +class EventType(str, Enum): """Zookeeper Event Represents a Zookeeper event. Events trigger watch functions which @@ -117,7 +124,7 @@ class EventType(object): } -class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): +class WatchedEvent(NamedTuple): """A change on ZooKeeper that a Watcher is able to respond to. The :class:`WatchedEvent` includes exactly what happened, the @@ -140,8 +147,12 @@ class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): """ + type: EventType + state: KeeperState + path: str | None + -class Callback(namedtuple("Callback", ("type", "func", "args"))): +class Callback(NamedTuple): """A callback that is handed to a handler for dispatch :param type: Type of the callback, currently is only 'watch' @@ -150,15 +161,12 @@ class Callback(namedtuple("Callback", ("type", "func", "args"))): """ + type: str + func: Callable[..., Any] + args: tuple[Any, ...] 
-class ZnodeStat( - namedtuple( - "ZnodeStat", - "czxid mzxid ctime mtime version" - " cversion aversion ephemeralOwner dataLength" - " numChildren pzxid", - ) -): + +class ZnodeStat(NamedTuple): """A ZnodeStat structure with convenience properties When getting the value of a znode from Zookeeper, the properties for @@ -216,38 +224,50 @@ class ZnodeStat( """ + czxid: int + mzxid: int + ctime: int + mtime: int + version: int + cversion: int + aversion: int + ephemeralOwner: int + dataLength: int + numChildren: int + pzxid: int + @property - def acl_version(self): + def acl_version(self) -> int: return self.aversion @property - def children_version(self): + def children_version(self) -> int: return self.cversion @property - def created(self): + def created(self) -> float: return self.ctime / 1000.0 @property - def last_modified(self): + def last_modified(self) -> float: return self.mtime / 1000.0 @property - def owner_session_id(self): + def owner_session_id(self) -> int | None: return self.ephemeralOwner or None @property - def creation_transaction_id(self): + def creation_transaction_id(self) -> int: return self.czxid @property - def last_modified_transaction_id(self): + def last_modified_transaction_id(self) -> int: return self.mzxid @property - def data_length(self): + def data_length(self) -> int: return self.dataLength @property - def children_count(self): + def children_count(self) -> int: return self.numChildren diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index 683e807b..26ffc935 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -4,12 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + import os import socket import uuid +from typing import Literal, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent + +if TYPE_CHECKING: + from kazoo.client import KazooClient class 
Barrier(object): @@ -27,7 +34,7 @@ class Barrier(object): """ - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """Create a Kazoo Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -37,11 +44,11 @@ def __init__(self, client, path): self.client = client self.path = path - def create(self): + def create(self) -> None: """Establish the barrier if it doesn't exist already""" self.client.retry(self.client.ensure_path, self.path) - def remove(self): + def remove(self) -> bool: """Remove the barrier :returns: Whether the barrier actually needed to be removed. @@ -54,7 +61,7 @@ def remove(self): except NoNodeError: return False - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Wait on the barrier to be cleared :returns: True if the barrier has been cleared, otherwise @@ -64,7 +71,7 @@ def wait(self, timeout=None): """ cleared = self.client.handler.event_object() - def wait_for_clear(event): + def wait_for_clear(event: WatchedEvent) -> None: if event.type == EventType.DELETED: cleared.set() @@ -93,7 +100,13 @@ class DoubleBarrier(object): """ - def __init__(self, client, path, num_clients, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + num_clients: int, + identifier: str | None = None, + ): """Create a Double Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. 
@@ -118,7 +131,7 @@ def __init__(self, client, path, num_clients, identifier=None): self.node_name = uuid.uuid4().hex self.create_path = self.path + "/" + self.node_name - def enter(self): + def enter(self) -> None: """Enter the barrier, blocks until all nodes have entered""" try: self.client.retry(self._inner_enter) @@ -128,7 +141,7 @@ def enter(self): self._best_effort_cleanup() self.participating = False - def _inner_enter(self): + def _inner_enter(self) -> Literal[True]: # make sure our barrier parent node exists if not self.assured_path: self.client.ensure_path(self.path) @@ -145,7 +158,7 @@ def _inner_enter(self): except NodeExistsError: pass - def created(event): + def created(event: WatchedEvent) -> None: if event.type == EventType.CREATED: ready.set() @@ -159,7 +172,7 @@ def created(event): self.client.ensure_path(self.path + "/ready") return True - def leave(self): + def leave(self) -> None: """Leave the barrier, blocks until all nodes have left""" try: self.client.retry(self._inner_leave) @@ -168,7 +181,7 @@ def leave(self): self._best_effort_cleanup() self.participating = False - def _inner_leave(self): + def _inner_leave(self) -> bool: # Delete the ready node if its around try: self.client.delete(self.path + "/ready") @@ -188,7 +201,7 @@ def _inner_leave(self): ready = self.client.handler.event_object() - def deleted(event): + def deleted(event: WatchedEvent) -> None: if event.type == EventType.DELETED: ready.set() @@ -214,7 +227,7 @@ def deleted(event): # Wait for the lowest to be deleted ready.wait() - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.retry(self.client.delete, self.create_path) except NoNodeError: diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 0a22a6c7..cee96c82 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -10,20 +10,42 @@ See also: http://curator.apache.org/curator-recipes/tree-cache.html """ + +from __future__ import annotations from __future__ import 
absolute_import import contextlib import functools import logging import operator +from typing import ( + Any, + Callable, + Generator, + Protocol, + TypeVar, + Tuple, + TYPE_CHECKING, + Union, + overload, +) + from kazoo.exceptions import NoNodeError, KazooException from kazoo.protocol.paths import _prefix_root, join as kazoo_join -from kazoo.protocol.states import KazooState, EventType +from kazoo.protocol.states import KazooState, EventType, ZnodeStat + +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import IAsyncResult, Threadlike + from kazoo.protocol.states import WatchedEvent logger = logging.getLogger(__name__) +ReturnValue = TypeVar("ReturnValue") + + class TreeCache(object): """The cache of a ZooKeeper subtree. @@ -37,18 +59,18 @@ class TreeCache(object): _STOP = object() - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): self._client = client self._root = TreeNode.make_root(self, path) self._state = self.STATE_LATENT self._outstanding_ops = 0 self._is_initialized = False - self._error_listeners = [] - self._event_listeners = [] + self._error_listeners: list[Callable[[Exception], None]] = [] + self._event_listeners: list[Callable[[TreeEvent], None]] = [] self._task_queue = client.handler.queue_impl() - self._task_thread = None + self._task_thread: Threadlike | None = None - def start(self): + def start(self) -> None: """Starts the cache. The cache is not started automatically. You must call this method. @@ -85,7 +107,7 @@ def start(self): # without lock. self._in_background(self._root.on_created) - def close(self): + def close(self) -> None: """Closes the cache. A closed cache was detached from ZooKeeper's changes. And all nodes @@ -109,7 +131,9 @@ def close(self): # ZooKeeper actually. 
self._root.on_deleted() - def listen(self, listener): + def listen( + self, listener: Callable[[TreeEvent], None] + ) -> Callable[[TreeEvent], None]: """Registers a function to listen the cache events. The cache events are changes of local data. They are delivered from @@ -124,7 +148,9 @@ def listen(self, listener): self._event_listeners.append(listener) return listener - def listen_fault(self, listener): + def listen_fault( + self, listener: Callable[[Exception], None] + ) -> Callable[[Exception], None]: """Registers a function to listen the exceptions. It is possible to meet some exceptions during the cache running. You @@ -138,7 +164,9 @@ def listen_fault(self, listener): self._error_listeners.append(listener) return listener - def get_data(self, path, default=None): + def get_data( + self, path: str, default: NodeData | None = None + ) -> NodeData | None: """Gets data of a node from cache. :param path: The absolute path string. @@ -150,7 +178,9 @@ def get_data(self, path, default=None): node = self._find_node(path) return default if node is None else node._data - def get_children(self, path, default=None): + def get_children( + self, path: str, default: frozenset[str] | None = None + ) -> frozenset[str] | None: """Gets node children list from in-memory snapshot. :param path: The absolute path string. @@ -158,11 +188,14 @@ def get_children(self, path, default=None): does not exist. :raises ValueError: If the path is outside of this subtree. :returns: The :class:`frozenset` which including children names. + + # FIXME the default return value should be an empty frozenset, + # returning None is confusing. 
""" node = self._find_node(path) return default if node is None else frozenset(node._children) - def _find_node(self, path): + def _find_node(self, path: str) -> TreeNode | None: if not path.startswith(self._root._path): raise ValueError("outside of tree") striped_path = path[len(self._root._path) :].strip("/") @@ -170,25 +203,49 @@ def _find_node(self, path): current_node = self._root for node_name in splited_path: if node_name not in current_node._children: - return + return None current_node = current_node._children[node_name] return current_node - def _publish_event(self, event_type, event_data=None): + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: logger.debug("public event: %r", event) self._in_background(self._do_publish_event, event) - def _do_publish_event(self, event): + def _do_publish_event(self, event: TreeEvent) -> None: for listener in self._event_listeners: with handle_exception(self._error_listeners): listener(event) - def _in_background(self, func, *args, **kwargs): + @overload + def _in_background( + self, func: Callable[[TreeEvent], None], event: TreeEvent + ) -> None: + ... + + @overload + def _in_background(self, func: Callable[[], None]) -> None: + ... + + @overload + def _in_background( + self, + func: Callable[[str, str, IAsyncResult], None], + method_name: str, + path: str, + result: IAsyncResult, + ) -> None: + ... 
+ + def _in_background( # type: ignore[misc] + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: self._task_queue.put((func, args, kwargs)) - def _do_background(self): + def _do_background(self) -> None: while True: with handle_exception(self._error_listeners): cb = self._task_queue.get() @@ -200,7 +257,7 @@ def _do_background(self): # release before possible idle del cb, func, args, kwargs - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.SUSPENDED: self._publish_event(TreeEvent.CONNECTION_SUSPENDED) elif state == KazooState.CONNECTED: @@ -212,6 +269,11 @@ def _session_watcher(self, state): self._publish_event(TreeEvent.CONNECTION_LOST) +class AsyncWatcher(Protocol): + def __call__(self, path: str, watch: WatchFunc | None) -> IAsyncResult: + ... + + class TreeNode(object): """The tree node record. @@ -234,28 +296,28 @@ class TreeNode(object): STATE_LIVE = 1 STATE_DEAD = 2 - def __init__(self, tree, path, parent): + def __init__(self, tree: TreeCache, path: str, parent: TreeNode | None): self._tree = tree self._path = path self._parent = parent - self._depth = parent._depth + 1 if parent else 0 - self._children = {} + self._depth: int = parent._depth + 1 if parent is not None else 0 + self._children: dict[str, TreeNode] = {} self._state = self.STATE_PENDING - self._data = None + self._data: NodeData | None = None @classmethod - def make_root(cls, tree, path): + def make_root(cls, tree: TreeCache, path: str) -> TreeNode: return cls(tree, path, None) - def on_reconnected(self): + def on_reconnected(self) -> None: self._refresh() for child in self._children.values(): child.on_reconnected() - def on_created(self): + def on_created(self) -> None: self._refresh() - def on_deleted(self): + def on_deleted(self) -> None: old_children, self._children = self._children, {} old_data, self._data = self._data, None @@ -278,37 +340,43 @@ def on_deleted(self): del self._parent._children[child] 
self._reset_watchers() - def _publish_event(self, *args, **kwargs): - return self._tree._publish_event(*args, **kwargs) + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: + return self._tree._publish_event(event_type, event_data) - def _reset_watchers(self): + def _reset_watchers(self) -> None: client = self._tree._client for _watchers in (client._data_watchers, client._child_watchers): _path = _prefix_root(client.chroot, self._path) _watcher = _watchers.get(_path, set()) _watcher.discard(self._process_watch) - def _refresh(self): + def _refresh(self) -> None: self._refresh_data() self._refresh_children() - def _refresh_data(self): + def _refresh_data(self) -> None: self._call_client("get", self._path) - def _refresh_children(self): + def _refresh_children(self) -> None: # TODO max-depth checking support self._call_client("get_children", self._path) - def _call_client(self, method_name, path): + def _call_client(self, method_name: str, path: str) -> None: assert method_name in ("get", "get_children", "exists") self._tree._outstanding_ops += 1 callback = functools.partial( self._tree._in_background, self._process_result, method_name, path ) - method = getattr(self._tree._client, method_name + "_async") + # The typing for this is really bad but the type checker can + # understand it with a few hacks + method: AsyncWatcher = getattr( + self._tree._client, method_name + "_async" + ) method(path, watch=self._process_watch).rawlink(callback) - def _process_watch(self, watched_event): + def _process_watch(self, watched_event: WatchedEvent) -> None: logger.debug("process_watch: %r", watched_event) with handle_exception(self._tree._error_listeners): if watched_event.type == EventType.CREATED: @@ -321,7 +389,9 @@ def _process_watch(self, watched_event): elif watched_event.type == EventType.CHILD: self._refresh_children() - def _process_result(self, method_name, path, result): + def _process_result( + self, method_name: str, path: 
str, result: IAsyncResult + ) -> None: logger.debug("process_result: %s %s", method_name, path) if method_name == "exists": assert self._parent is None, "unexpected EXISTS on non-root" @@ -332,7 +402,7 @@ def _process_result(self, method_name, path, result): self.on_created() elif method_name == "get_children": if result.successful(): - children = result.get() + children: list[str] = result.get() for child in sorted(children): full_path = kazoo_join(path, child) if child not in self._children: @@ -367,9 +437,12 @@ def _process_result(self, method_name, path, result): self._publish_event(TreeEvent.INITIALIZED) -class TreeEvent(tuple): +# FIXME these Tuple-based classes would look a lot better using NamedTuple +# though the event_type in TreeEvent needs sorting. +class TreeEvent(Tuple[int, Union["NodeData", None]]): """The immutable event tuple of cache.""" + # FIXME These should be an enum. NODE_ADDED = 0 NODE_UPDATED = 1 NODE_REMOVED = 2 @@ -385,7 +458,9 @@ class TreeEvent(tuple): event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type, event_data): + def make( + cls, event_type: int, event_data: NodeData | None = None + ) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. @@ -402,7 +477,7 @@ def make(cls, event_type, event_data): return cls((event_type, event_data)) -class NodeData(tuple): +class NodeData(Tuple[str, bytes, ZnodeStat]): """The immutable node data tuple of cache.""" #: The absolute path string of current node. @@ -415,7 +490,7 @@ class NodeData(tuple): stat = property(operator.itemgetter(2)) @classmethod - def make(cls, path, data, stat): + def make(cls, path: str, data: bytes, stat: ZnodeStat) -> NodeData: """Creates a new NodeData tuple. :returns: A :class:`~kazoo.recipe.cache.NodeData` instance. 
@@ -424,7 +499,9 @@ def make(cls, path, data, stat): @contextlib.contextmanager -def handle_exception(listeners): +def handle_exception( + listeners: list[Callable[[Exception], None]], +) -> Generator[None, None, None]: try: yield except Exception as e: diff --git a/kazoo/recipe/counter.py b/kazoo/recipe/counter.py index 3b2cc339..53132252 100644 --- a/kazoo/recipe/counter.py +++ b/kazoo/recipe/counter.py @@ -4,9 +4,20 @@ :Status: Unknown """ + +from __future__ import annotations + +import struct +from typing import Union, TYPE_CHECKING + from kazoo.exceptions import BadVersionError from kazoo.retry import ForceRetryError -import struct + +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +Number = Union[int, float] class Counter(object): @@ -58,7 +69,13 @@ class Counter(object): """ - def __init__(self, client, path, default=0, support_curator=False): + def __init__( + self, + client: KazooClient, + path: str, + default: Number = 0, + support_curator: bool = False, + ): """Create a Kazoo Counter :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -74,22 +91,24 @@ def __init__(self, client, path, default=0, support_curator=False): self.default_type = type(default) self.support_curator = support_curator self._ensured_path = False - self.pre_value = None - self.post_value = None + self.pre_value: Number | None = None + self.post_value: Number | None = None if self.support_curator and not isinstance(self.default, int): raise TypeError( "when support_curator is enabled the default " "type must be an int" ) - def _ensure_node(self): + def _ensure_node(self) -> None: if not self._ensured_path: # make sure our node exists self.client.ensure_path(self.path) self._ensured_path = True - def _value(self): + def _value(self) -> tuple[Number, int]: self._ensure_node() + # This is astonishingly hard to follow... 
+ old: Union[bytes, str, Number] old, stat = self.client.get(self.path) if self.support_curator: old = struct.unpack(">i", old)[0] if old != b"" else self.default @@ -100,16 +119,16 @@ def _value(self): return data, version @property - def value(self): + def value(self) -> Number: return self._value()[0] - def _change(self, value): + def _change(self, value: Number) -> Counter: if not isinstance(value, self.default_type): raise TypeError("invalid type for value change") self.client.retry(self._inner_change, value) return self - def _inner_change(self, value): + def _inner_change(self, value: Number) -> None: self.pre_value, version = self._value() post_value = self.pre_value + value if self.support_curator: @@ -123,10 +142,10 @@ def _inner_change(self, value): raise ForceRetryError() self.post_value = post_value - def __add__(self, value): + def __add__(self, value: Number) -> Counter: """Add value to counter.""" return self._change(value) - def __sub__(self, value): + def __sub__(self, value: Number) -> Counter: """Subtract value from counter.""" return self._change(-value) diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 93bb7258..1e28517b 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -4,8 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING +from typing_extensions import ParamSpec + from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + +GenericArgs = ParamSpec("GenericArgs") + class Election(object): """Kazoo Basic Leader Election @@ -22,7 +33,12 @@ class Election(object): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """Create a Kazoo Leader Election :param client: A :class:`~kazoo.client.KazooClient` instance. 
@@ -34,7 +50,12 @@ def __init__(self, client, path, identifier=None): """ self.lock = client.Lock(path, identifier) - def run(self, func, *args, **kwargs): + def run( + self, + func: Callable[GenericArgs, None], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> None: """Contend for the leadership This call will block until either this contender is cancelled @@ -57,7 +78,7 @@ def run(self, func, *args, **kwargs): except CancelledError: pass - def cancel(self): + def cancel(self) -> None: """Cancel participation in the election .. note:: @@ -68,7 +89,7 @@ def cancel(self): """ self.lock.cancel() - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders in the election diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index ce7fe567..1e4c9cf8 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -5,12 +5,25 @@ :Status: Beta """ + +from __future__ import annotations + import datetime import json import socket +from typing import Callable, TypedDict, TYPE_CHECKING, cast from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +class Lease(TypedDict): + version: int + holder: str + end: str + class NonBlockingLease(object): """Exclusive lease that does not block. @@ -48,11 +61,11 @@ class NonBlockingLease(object): def __init__( self, - client, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): """Create a non-blocking lease. 
@@ -71,7 +84,14 @@ def __init__( self.obtained = False self._attempt_obtaining(client, path, duration, ident, utcnow) - def _attempt_obtaining(self, client, path, duration, ident, utcnow): + def _attempt_obtaining( + self, + client: KazooClient, + path: str, + duration: datetime.timedelta, + ident: str, + utcnow: Callable[[], datetime.datetime], + ) -> None: client.ensure_path(path) holder_path = path + "/lease_holder" lock = client.Lock(path, ident) @@ -92,7 +112,7 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): return client.delete(holder_path) end_lease = (now + duration).strftime(self._date_format) - new_data = { + new_data: Lease = { "version": self._version, "holder": ident, "end": end_lease, @@ -103,18 +123,13 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): except CancelledError: pass - def _encode(self, data_dict): + def _encode(self, data_dict: Lease) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) - def _decode(self, raw): - return json.loads(raw.decode(self._byte_encoding)) - - # Python 2.x - def __nonzero__(self): - return self.obtained + def _decode(self, raw: bytes) -> Lease: + return cast("Lease", json.loads(raw.decode(self._byte_encoding))) - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained @@ -140,12 +155,12 @@ class MultiNonBlockingLease(object): def __init__( self, - client, - count, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + count: int, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): self.obtained = False for num in range(count): @@ -160,10 +175,5 @@ def __init__( self.obtained = True break - # Python 2.x - def __nonzero__(self): - return self.obtained - - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/lock.py 
b/kazoo/recipe/lock.py index 1f524702..6f3236db 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -14,9 +14,19 @@ and/or the lease has been lost. """ + +from __future__ import annotations + import re import time import uuid +from typing import ( + Iterable, + Literal, + Pattern, + TYPE_CHECKING, +) +from types import TracebackType from kazoo.exceptions import ( CancelledError, @@ -24,27 +34,37 @@ LockTimeout, NoNodeError, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent from kazoo.retry import ( ForceRetryError, KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient + class _Watch(object): - def __init__(self, duration=None): + def __init__(self, duration: float | None = None): self.duration = duration - self.started_at = None + self.started_at: float | None = None - def start(self): + def start(self) -> None: self.started_at = time.monotonic() - def leftover(self): + def leftover(self) -> float | None: if self.duration is None: return None else: - elapsed = time.monotonic() - self.started_at + # We should probably set started_at to either 0 or + # time.monotonic() in __init__ to avoid the type ignore + # here, but this is a private class and it's pretty clear + # that start() should be called before leftover() so I'm + # not sure it's worth it. + elapsed = ( + time.monotonic() - self.started_at # type: ignore[operator] + ) return max(0, self.duration - elapsed) @@ -77,7 +97,13 @@ class Lock(object): # sequence number. Involved in read/write locks. _EXCLUDE_NAMES = ["__lock__"] - def __init__(self, client, path, identifier=None, extra_lock_patterns=()): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + extra_lock_patterns: Iterable[str] = (), + ): """Create a Kazoo lock. :param client: A :class:`~kazoo.client.KazooClient` instance. 
@@ -97,10 +123,10 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): """ self.client = client self.path = path - self._exclude_names = set( + self._exclude_names: set[str] = set( self._EXCLUDE_NAMES + list(extra_lock_patterns) ) - self._contenders_re = re.compile( + self._contenders_re: Pattern[str] = re.compile( r"(?:{patterns})(-?\d{{10}})$".format( patterns="|".join(self._exclude_names) ) @@ -109,7 +135,7 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock self.data = str(identifier or "").encode("utf-8") - self.node = None + self.node: str | None = None self.wake_event = client.handler.event_object() @@ -129,16 +155,21 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): ) self._acquire_method_lock = client.handler.lock_object() - def _ensure_path(self): + def _ensure_path(self) -> None: self.client.ensure_path(self.path) self.assured_path = True - def cancel(self): + def cancel(self) -> None: """Cancel a pending lock acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None, ephemeral=True): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ephemeral: bool = True, + ) -> bool: """ Acquire the lock. By defaults blocks and waits forever. @@ -204,11 +235,16 @@ def acquire(self, blocking=True, timeout=None, ephemeral=True): finally: self._acquire_method_lock.release() - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool: self.wake_event.set() return True - def _inner_acquire(self, blocking, timeout, ephemeral=True): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None, + ephemeral: bool = True, + ) -> bool: # wait until it's our chance to get it.. 
if self.is_acquired: if not blocking: @@ -219,7 +255,7 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): if not self.assured_path: self._ensure_path() - node = None + node: str | None = None if self.create_tried: node = self._find_node() else: @@ -265,10 +301,10 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): finally: self.client.remove_listener(self._watch_session) - def _watch_predecessor(self, event): + def _watch_predecessor(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_predecessor(self, node): + def _get_predecessor(self, node: str) -> str | None: """returns `node`'s predecessor or None Note: This handle the case where the current lock is not a contender @@ -277,7 +313,7 @@ def _get_predecessor(self, node): """ node_sequence = node[len(self.prefix) :] children = self.client.get_children(self.path) - found_self = False + found_self: Literal[False] | re.Match[str] | None = False # Filter out the contenders using the computed regex contender_matches = [] for child in children: @@ -308,17 +344,17 @@ def _get_predecessor(self, node): sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) return sorted_matches[-1].string - def _find_node(self): + def _find_node(self) -> str | None: children = self.client.get_children(self.path) for child in children: if child.startswith(self.prefix): return child return None - def _delete_node(self, node): + def _delete_node(self, node: str) -> None: self.client.delete(self.path + "/" + node) - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: node = self.node or self._find_node() if node: @@ -326,16 +362,18 @@ def _best_effort_cleanup(self): except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lock immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: - 
self._delete_node(self.node) + # I don't think it's possible for self.node to be None here if + # self.is_acquired is true. + self._delete_node(self.node) # type: ignore[arg-type] except NoNodeError: # pragma: nocover pass @@ -343,7 +381,7 @@ def _inner_release(self): self.node = None return True - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders for the lock. @@ -390,11 +428,17 @@ def contenders(self): return contenders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: self.release() + return None class WriteLock(Lock): @@ -492,7 +536,13 @@ class Semaphore(object): """ - def __init__(self, client, path, identifier=None, max_leases=1): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + max_leases: int = 1, + ): """Create a Kazoo Lock :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -528,7 +578,7 @@ def __init__(self, client, path, identifier=None, max_leases=1): self.cancelled = False self._session_expired = False - def _ensure_path(self): + def _ensure_path(self) -> None: result = self.client.ensure_path(self.path) self.assured_path = True if result is True: @@ -549,12 +599,16 @@ def _ensure_path(self): else: self.client.set(self.path, str(self.max_leases).encode("utf-8")) - def cancel(self): + def cancel(self) -> None: """Cancel a pending semaphore acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ) -> bool: """Acquire the semaphore. By defaults blocks and waits forever. 
:param blocking: Block until semaphore is obtained or @@ -592,7 +646,11 @@ def acquire(self, blocking=True, timeout=None): return self.is_acquired - def _inner_acquire(self, blocking, timeout=None): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None = None, + ) -> bool: """Inner loop that runs from the top anytime a command hits a retryable Zookeeper exception.""" self._session_expired = False @@ -607,7 +665,12 @@ def _inner_acquire(self, blocking, timeout=None): w = _Watch(duration=timeout) w.start() - lock = self.client.Lock(self.lock_path, self.data) + # FIXME This is passing bytes data, but self.client.Lock expects a str, + # which I think is a bug in this code. However, I don't want to + # change any code at this point, so we just ignore the type error here. + lock = self.client.Lock( + self.lock_path, self.data # type: ignore[arg-type] + ) try: gotten = lock.acquire(blocking=blocking, timeout=w.leftover()) if not gotten: @@ -633,10 +696,10 @@ def _inner_acquire(self, blocking, timeout=None): finally: lock.release() - def _watch_lease_change(self, event): + def _watch_lease_change(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_lease(self, data=None): + def _get_lease(self) -> bool: # Make sure the session is still valid if self._session_expired: raise ForceRetryError("Retry on session loss at top") @@ -665,25 +728,26 @@ def _get_lease(self, data=None): # Return current state return self.is_acquired - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool | None: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() # Return true to de-register return True + return None - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.delete(self.create_path) except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lease immediately.""" return self.client.retry(self._inner_release) 
- def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: @@ -693,7 +757,7 @@ def _inner_release(self): self.is_acquired = False return True - def lease_holders(self): + def lease_holders(self) -> list[str]: """Return an unordered list of the current lease holders. .. note:: @@ -716,8 +780,13 @@ def lease_holders(self): pass return lease_holders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.release() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 21dc6ef4..19c48c28 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -17,20 +17,30 @@ so that no two workers own the same queue. """ + +from __future__ import annotations + from functools import partial import logging import os import socket +from enum import Enum +from typing import Callable, Iterator, Sequence, TYPE_CHECKING from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState from kazoo.recipe.watchers import PatientChildrenWatch +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import Event, IAsyncResult + from kazoo.recipe.lock import Lock log = logging.getLogger(__name__) -class PartitionState(object): +# This is a (str, Enum) for backwards compatibility. +class PartitionState(str, Enum): """High level partition state values .. 
attribute:: ALLOCATING @@ -139,14 +149,15 @@ class SetPartitioner(object): def __init__( self, - client, - path, - set, - partition_func=None, - identifier=None, - time_boundary=30, - max_reaction_time=1, - state_change_event=None, + client: KazooClient, + path: str, + set: Sequence[str], + partition_func: Callable[[str, list[str], Sequence[str]], list[str]] + | None = None, + identifier: str | None = None, + time_boundary: float = 30, + max_reaction_time: float = 1, + state_change_event: Event | None = None, ): """Create a :class:`~SetPartitioner` instance @@ -176,13 +187,13 @@ def __init__( self._client = client self._path = path self._set = set - self._partition_set = [] + self._partition_set: list[str] = [] self._partition_func = partition_func or self._partitioner self._identifier = identifier or "%s-%s" % ( socket.getfqdn(), os.getpid(), ) - self._locks = [] + self._locks: list[Lock] = [] self._lock_path = "/".join([path, "locks"]) self._party_path = "/".join([path, "party"]) self._time_boundary = time_boundary @@ -208,33 +219,33 @@ def __init__( # so we know when we're ready self._child_watching(self._allocate_transition, client_handler=True) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Return the partitions in this partition set""" for partition in self._partition_set: yield partition @property - def failed(self): + def failed(self) -> bool: """Corresponds to the :attr:`PartitionState.FAILURE` state""" return self.state == PartitionState.FAILURE @property - def release(self): + def release(self) -> bool: """Corresponds to the :attr:`PartitionState.RELEASE` state""" return self.state == PartitionState.RELEASE @property - def allocating(self): + def allocating(self) -> bool: """Corresponds to the :attr:`PartitionState.ALLOCATING` state""" return self.state == PartitionState.ALLOCATING @property - def acquired(self): + def acquired(self) -> bool: """Corresponds to the :attr:`PartitionState.ACQUIRED` state""" return self.state == 
PartitionState.ACQUIRED - def wait_for_acquire(self, timeout=30): + def wait_for_acquire(self, timeout: float = 30) -> None: """Wait for the set to be partitioned and acquired :param timeout: How long to wait before returning. @@ -243,7 +254,7 @@ def wait_for_acquire(self, timeout=30): """ self._acquire_event.wait(timeout) - def release_set(self): + def release_set(self) -> None: """Call to release the set This method begins the step of allocating once the set has @@ -263,12 +274,12 @@ def release_set(self): self._set_state(PartitionState.ALLOCATING) self._child_watching(self._allocate_transition, client_handler=True) - def finish(self): + def finish(self) -> None: """Call to release the set and leave the party""" self._release_locks() self._fail_out() - def _fail_out(self): + def _fail_out(self) -> None: with self._state_change: self._set_state(PartitionState.FAILURE) if self._party.participating: @@ -277,7 +288,7 @@ def _fail_out(self): except KazooException: # pragma: nocover pass - def _allocate_transition(self, result): + def _allocate_transition(self, result: IAsyncResult) -> None: """Called when in allocating mode, and the children settled""" # Did we get an exception waiting for children to settle? @@ -288,7 +299,7 @@ def _allocate_transition(self, result): children, async_result = result.get() children_changed = self._client.handler.event_object() - def updated(result): + def updated(result: IAsyncResult) -> None: with self._state_change: children_changed.set() if self.acquired: @@ -307,7 +318,7 @@ def updated(result): # Check whether the state has changed during the lock acquisition # and abort the process if so. - def abort_if_needed(): + def abort_if_needed() -> bool: if self.state_id == state_id: if children_changed.is_set(): # The party has changed. Repartitioning... @@ -365,7 +376,7 @@ def abort_if_needed(): # This mustn't happen. Means a logical error. 
self._fail_out() - def _release_locks(self): + def _release_locks(self) -> None: """Attempt to completely remove all the locks""" self._acquire_event.clear() for lock in self._locks[:]: @@ -378,7 +389,7 @@ def _release_locks(self): else: self._locks.remove(lock) - def _abort_lock_acquisition(self): + def _abort_lock_acquisition(self) -> None: """Called during lock acquisition if a party change occurs""" self._release_locks() @@ -391,7 +402,13 @@ def _abort_lock_acquisition(self): self._child_watching(self._allocate_transition, client_handler=True) - def _child_watching(self, func=None, client_handler=False): + # FIXME This is only ever called with func=self._allocation_transition, but + # I didn't want to change the code. + def _child_watching( + self, + func: Callable[[IAsyncResult], None] | None = None, + client_handler: bool = False, + ) -> IAsyncResult: """Called when children are being watched to stabilize This actually returns immediately, child watcher spins up a @@ -410,11 +427,15 @@ def _child_watching(self, func=None, client_handler=False): # to ensure that the rawlink's it might use won't be # blocked if client_handler: - func = partial(self._client.handler.spawn, func) + # FIXME This feels wrong, but it may be because partial is + # confusing things. 
+ func = partial( # type: ignore[assignment] + self._client.handler.spawn, func + ) asy.rawlink(func) return asy - def _establish_sessionwatch(self, state): + def _establish_sessionwatch(self, state: KazooState) -> bool: """Register ourself to listen for session events, we shut down if we become lost""" with self._state_change: @@ -427,7 +448,12 @@ def _establish_sessionwatch(self, state): return state == KazooState.LOST - def _partitioner(self, identifier, members, partitions): + def _partitioner( + self, + identifier: str, + members: list[str], + partitions: Sequence[str], + ) -> list[str]: # Ensure consistent order of partitions/members all_partitions = sorted(partitions) workers = sorted(members) @@ -437,7 +463,7 @@ def _partitioner(self, identifier, members, partitions): # skipping the other workers return all_partitions[i :: len(workers)] - def _set_state(self, state): + def _set_state(self, state: PartitionState) -> None: self.state = state self.state_id += 1 self.state_change_event.set() diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index 2a0f5dfb..1fc1340b 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -7,15 +7,27 @@ used for determining members of a party. """ + +from __future__ import annotations + import uuid +from typing import Iterator, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseParty(object): """Base implementation of a party.""" - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The party path to use. 
@@ -29,44 +41,52 @@ def __init__(self, client, path, identifier=None): self.ensured_path = False self.participating = False - def _ensure_parent(self): + def _ensure_parent(self) -> None: if not self.ensured_path: # make sure our parent node exists self.client.ensure_path(self.path) self.ensured_path = True - def join(self): + def join(self) -> None: """Join the party""" return self.client.retry(self._inner_join) - def _inner_join(self): + def _inner_join(self) -> None: self._ensure_parent() try: - self.client.create(self.create_path, self.data, ephemeral=True) + # This and the #type: ignore[attr-defined] below could be removed + # by setting up create_path in the constructor but trying to avoid + # changing the code too much. It does actually cause later versions + # of pylint to error though. + self.client.create( + self.create_path, # type: ignore[attr-defined] + self.data, + ephemeral=True, + ) self.participating = True except NodeExistsError: # node was already created, perhaps we are recovering from a # suspended connection self.participating = True - def leave(self): + def leave(self) -> bool: """Leave the party""" self.participating = False return self.client.retry(self._inner_leave) - def _inner_leave(self): + def _inner_leave(self) -> bool: try: - self.client.delete(self.create_path) + self.client.delete(self.create_path) # type: ignore[attr-defined] except NoNodeError: return False return True - def __len__(self): + def __len__(self) -> int: """Return a count of participating clients""" self._ensure_parent() return len(self._get_children()) - def _get_children(self): + def _get_children(self) -> list[str]: return self.client.retry(self.client.get_children, self.path) @@ -75,12 +95,14 @@ class Party(BaseParty): _NODE_NAME = "__party__" - def __init__(self, client, path, identifier=None): + def __init__( + self, client: KazooClient, path: str, identifier: str | None = None + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = 
uuid.uuid4().hex + self._NODE_NAME self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' data values""" self._ensure_parent() children = self._get_children() @@ -93,7 +115,7 @@ def __iter__(self): except NoNodeError: # pragma: nocover pass - def _get_children(self): + def _get_children(self) -> list[str]: children = BaseParty._get_children(self) return [c for c in children if self._NODE_NAME in c] @@ -109,12 +131,17 @@ class ShallowParty(BaseParty): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = "-".join([uuid.uuid4().hex, self.data.decode("utf-8")]) self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' identifiers""" self._ensure_parent() children = self._get_children() diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index 30d3066e..85a86676 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -9,17 +9,24 @@ See: https://github.com/python-zk/kazoo/issues/175 """ + +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent from kazoo.retry import ForceRetryError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseQueue(object): """A common base class for queue implementations.""" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. 
@@ -27,10 +34,10 @@ def __init__(self, client, path): self.client = client self.path = path self._entries_path = path - self.structure_paths = (self.path,) + self.structure_paths: tuple[str, ...] = (self.path,) self.ensured_path = False - def _check_put_arguments(self, value, priority=100): + def _check_put_arguments(self, value: bytes, priority: int = 100) -> None: if not isinstance(value, bytes): raise TypeError("value must be a byte string") if not isinstance(priority, int): @@ -38,14 +45,14 @@ def _check_put_arguments(self, value, priority=100): elif priority < 0 or priority > 999: raise ValueError("priority must be between 0 and 999") - def _ensure_paths(self): + def _ensure_paths(self) -> None: if not self.ensured_path: # make sure our parent / internal structure nodes exists for path in self.structure_paths: self.client.ensure_path(path) self.ensured_path = True - def __len__(self): + def __len__(self) -> int: self._ensure_paths() _, stat = self.client.retry(self.client.get, self._entries_path) return stat.children_count @@ -62,19 +69,19 @@ class Queue(BaseQueue): prefix = "entry-" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(Queue, self).__init__(client, path) - self._children = [] + self._children: list[str] = [] - def __len__(self): + def __len__(self) -> int: """Return queue size.""" return super(Queue, self).__len__() - def get(self): + def get(self) -> bytes | None: """ Get item data and remove an item from the queue. 
@@ -84,7 +91,7 @@ def get(self): self._ensure_paths() return self.client.retry(self._inner_get) - def _inner_get(self): + def _inner_get(self) -> bytes | None: if not self._children: self._children = self.client.retry( self.client.get_children, self.path @@ -105,7 +112,7 @@ def _inner_get(self): self._children.pop(0) return data - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an item into the queue. :param value: Byte string to put into the queue. @@ -150,26 +157,26 @@ class LockingQueue(BaseQueue): entries = "/entries" entry = "entry" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(LockingQueue, self).__init__(client, path) self.id = uuid.uuid4().hex.encode() - self.processing_element = None + self.processing_element: tuple[str, bytes] | None = None self._lock_path = self.path + self.lock self._entries_path = self.path + self.entries self.structure_paths = (self._lock_path, self._entries_path) - def __len__(self): + def __len__(self) -> int: """Returns the current length of the queue. :returns: queue size (includes locked entries count). """ return super(LockingQueue, self).__len__() - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an entry into the queue. :param value: Byte string to put into the queue. @@ -189,7 +196,7 @@ def put(self, value, priority=100): sequence=True, ) - def put_all(self, values, priority=100): + def put_all(self, values: list[bytes], priority: int = 100) -> None: """Put several entries into the queue. The action only succeeds if all entries where put into the queue. 
@@ -221,7 +228,7 @@ def put_all(self, values, priority=100): sequence=True, ) - def get(self, timeout=None): + def get(self, timeout: float | None = None) -> bytes | None: """Locks and gets an entry from the queue. If a previously got entry was not consumed, this method will return that entry. @@ -237,7 +244,7 @@ def get(self, timeout=None): else: return self._inner_get(timeout) - def holds_lock(self): + def holds_lock(self) -> bool: """Checks if a node still holds the lock. :returns: True if a node still holds the lock, False otherwise. @@ -251,7 +258,7 @@ def holds_lock(self): value, stat = self.client.retry(self.client.get, lock_path) return value == self.id - def consume(self): + def consume(self) -> bool: """Removes a currently processing entry from the queue. :returns: True if element was removed successfully, False otherwise. @@ -271,7 +278,7 @@ def consume(self): else: return False - def release(self): + def release(self) -> bool: """Removes the lock from currently processed item without consuming it. :returns: True if the lock was removed successfully, False otherwise. 
@@ -289,13 +296,13 @@ def release(self): else: return False - def _inner_get(self, timeout): + def _inner_get(self, timeout: float | None) -> bytes | None: flag = self.client.handler.event_object() lock = self.client.handler.lock_object() canceled = False value = [] - def check_for_updates(event): + def check_for_updates(event: WatchedEvent | None) -> None: if event is not None and event.type != EventType.CHILD: return with lock: @@ -330,8 +337,8 @@ def check_for_updates(event): retVal = value[0][1] return retVal - def _filter_locked(self, values, taken): - taken = set(taken) + def _filter_locked(self, values: list[str], taken: list[str]) -> list[str]: + taken = set(taken) # type: ignore[assignment] available = sorted(values) return ( available @@ -339,7 +346,7 @@ def _filter_locked(self, values, taken): else [x for x in available if x not in taken] ) - def _take(self, id_): + def _take(self, id_: str) -> tuple[str, bytes] | None: try: self.client.create( "{path}/{id}".format(path=self._lock_path, id=id_), diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index d4cb0300..aa82e547 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -10,25 +10,45 @@ will result in an exception being thrown. 
""" + +from __future__ import annotations + from functools import partial, wraps import logging import time import warnings +from typing import ( + Any, + List, + Callable, + Optional, + Union, + TYPE_CHECKING, + overload, +) +from typing_extensions import ParamSpec from kazoo.exceptions import ConnectionClosedError, NoNodeError, KazooException -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat from kazoo.retry import KazooRetry +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import IAsyncResult log = logging.getLogger(__name__) _STOP_WATCHING = object() +GenericArgs = ParamSpec("GenericArgs") + -def _ignore_closed(func): +def _ignore_closed( + func: Callable[GenericArgs, None] +) -> Callable[GenericArgs, None]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: GenericArgs.args, **kwargs: GenericArgs.kwargs) -> None: try: return func(*args, **kwargs) except ConnectionClosedError: @@ -37,6 +57,15 @@ def wrapper(*args, **kwargs): return wrapper +DataWatchFunc = Union[ + Callable[[Optional[bytes], Optional[ZnodeStat]], Optional[bool]], + Callable[ + [Optional[bytes], Optional[ZnodeStat], Optional[WatchedEvent]], + Optional[bool], + ], +] + + class DataWatch(object): """Watches a node for data updates and calls the specified function each time it changes @@ -88,7 +117,35 @@ def my_func(data, stat, event): """ - def __init__(self, client, path, func=None, *args, **kwargs): + @overload + def __init__( + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + ): + ... + + # FIXME This would get a @warnings.deprecated in py13+ + @overload + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): + ... 
+ + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): """Create a data watcher for a path :param client: A zookeeper client. @@ -107,7 +164,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._func = func self._stopped = False self._run_lock = client.handler.lock_object() - self._version = None + self._version: int | None = None self._retry = KazooRetry( max_tries=None, sleep_func=client.handler.sleep_func ) @@ -132,7 +189,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._client.add_listener(self._session_watcher) self._get_data() - def __call__(self, func): + def __call__(self, func: DataWatchFunc) -> DataWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -155,16 +212,28 @@ def __call__(self, func): self._get_data() return func - def _log_func_exception(self, data, stat, event=None): + def _log_func_exception( + self, + data: bytes | None, + stat: ZnodeStat | None, + event: WatchedEvent | None = None, + ) -> None: try: # For backwards compatibility, don't send event to the # callback unless the send_event is set in constructor if not self._ever_called: self._ever_called = True try: - result = self._func(data, stat, event) + # The type ignores here are because mypy can't figure out that + # 1) self._func can't ever be None (fingers crossed) + # 2) the function can be called with 2 arguments or with 3 + # arguments (though that could possibly be done with better + # typing) + result = self._func( # type: ignore[call-arg, misc] + data, stat, event + ) except TypeError: - result = self._func(data, stat) + result = self._func(data, stat) # type: ignore[call-arg, misc] if result is False: self._stopped = True self._func = None @@ -174,7 +243,7 @@ def _log_func_exception(self, data, stat, event=None): raise @_ignore_closed - def _get_data(self, event=None): + def 
_get_data(self, event: WatchedEvent | None = None) -> None: # Ensure this runs one at a time, possible because the session # watcher may trigger a run with self._run_lock: @@ -183,6 +252,7 @@ def _get_data(self, event=None): initial_version = self._version + stat: ZnodeStat | None try: data, stat = self._retry( self._client.get, self._path, self._watcher @@ -210,18 +280,24 @@ def _get_data(self, event=None): if initial_version != self._version or not self._ever_called: self._log_func_exception(data, stat, event) - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: self._get_data(event=event) - def _set_watch(self, state): + def _set_watch(self, state: KazooState) -> None: with self._run_lock: self._watch_established = state - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.CONNECTED: self._client.handler.spawn(self._get_data) +ChildrenWatchFunc = Union[ + Callable[[List[str]], Optional[bool]], + Callable[[List[str], Optional[WatchedEvent]], Optional[bool]], +] + + class ChildrenWatch(object): """Watches a node for children updates and calls the specified function each time it changes @@ -253,11 +329,11 @@ def my_func(children): def __init__( self, - client, - path, - func=None, - allow_session_lost=True, - send_event=False, + client: KazooClient, + path: str, + func: ChildrenWatchFunc | None = None, + allow_session_lost: bool = True, + send_event: bool = False, ): """Create a children watcher for a path @@ -290,7 +366,7 @@ def __init__( self._watch_established = False self._allow_session_lost = allow_session_lost self._run_lock = client.handler.lock_object() - self._prior_children = None + self._prior_children: list[str] | None = None self._used = False # Register our session listener if we're going to resume @@ -301,7 +377,7 @@ def __init__( self._client.add_listener(self._session_watcher) self._get_children() - def __call__(self, func): + def __call__(self, func: 
ChildrenWatchFunc) -> ChildrenWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -325,7 +401,7 @@ def __call__(self, func): return func @_ignore_closed - def _get_children(self, event=None): + def _get_children(self, event: WatchedEvent | None = None) -> None: with self._run_lock: # Ensure this runs one at a time if self._stopped: return @@ -351,9 +427,16 @@ def _get_children(self, event=None): try: if self._send_event: - result = self._func(children, event) + # See comment about the type ignore here in DataWatch, + # it's the same issue where mypy can't figure out that the + # function can be called with 1 argument or with 2 + result = self._func( # type: ignore[misc] + children, event # type: ignore[call-arg] + ) else: - result = self._func(children) + result = self._func( # type: ignore[misc] + children # type: ignore[call-arg] + ) if result is False: self._stopped = True self._func = None @@ -363,11 +446,11 @@ def _get_children(self, event=None): log.exception(exc) raise - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: if event.type != "NONE": self._get_children(event) - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state in (KazooState.LOST, KazooState.SUSPENDED): self._watch_established = False elif ( @@ -408,14 +491,16 @@ class PatientChildrenWatch(object): """ - def __init__(self, client, path, time_boundary=30): + def __init__( + self, client: KazooClient, path: str, time_boundary: float = 30 + ): self.client = client self.path = path - self.children = [] + self.children: list[str] = [] self.time_boundary = time_boundary self.children_changed = client.handler.event_object() - def start(self): + def start(self) -> IAsyncResult: """Begin the watching process asynchronously :returns: An :class:`~kazoo.interfaces.IAsyncResult` instance @@ -427,7 +512,7 @@ def start(self): self.client.handler.spawn(self._inner_start) 
return asy - def _inner_start(self): + def _inner_start(self) -> None: try: while True: async_result = self.client.handler.async_result() @@ -447,6 +532,8 @@ def _inner_start(self): except Exception as exc: self.asy.set_exception(exc) - def _children_watcher(self, async_result, event): + def _children_watcher( + self, async_result: IAsyncResult, event: WatchedEvent + ) -> None: self.children_changed.set() async_result.set(time.monotonic()) diff --git a/kazoo/retry.py b/kazoo/retry.py index fb9e8fc7..9e4e0c9a 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import random import time +from typing import Any, Callable, TypeVar from kazoo.exceptions import ( ConnectionClosedError, @@ -10,7 +13,6 @@ SessionExpiredError, ) - log = logging.getLogger(__name__) @@ -37,18 +39,20 @@ class KazooRetry(object): EXPIRED_EXCEPTIONS = (SessionExpiredError,) + RETRY_RETURN = TypeVar("RETRY_RETURN") + def __init__( self, - max_tries=1, - delay=0.1, - backoff=2, - max_jitter=0.4, - max_delay=60.0, - ignore_expire=True, - sleep_func=time.sleep, - deadline=None, - interrupt=None, - ): + max_tries: int | None = 1, + delay: float = 0.1, + backoff: int = 2, + max_jitter: float = 0.4, + max_delay: float = 60.0, + ignore_expire: bool = True, + sleep_func: Callable[[float], None] = time.sleep, + deadline: float | None = None, + interrupt: Callable[[], bool] | None = None, + ) -> None: """Create a :class:`KazooRetry` instance for retrying function calls. @@ -81,20 +85,22 @@ def __init__( self._attempts = 0 self._cur_delay = delay self.deadline = deadline - self._cur_stoptime = None + self._cur_stoptime: float | None = None self.sleep_func = sleep_func - self.retry_exceptions = self.RETRY_EXCEPTIONS + self.retry_exceptions: tuple[ + type[Exception], ... 
+ ] = self.RETRY_EXCEPTIONS self.interrupt = interrupt if ignore_expire: self.retry_exceptions += self.EXPIRED_EXCEPTIONS - def reset(self): + def reset(self) -> None: """Reset the attempt counter""" self._attempts = 0 self._cur_delay = self.delay self._cur_stoptime = None - def copy(self): + def copy(self) -> KazooRetry: """Return a clone of this retry manager""" obj = KazooRetry( max_tries=self.max_tries, @@ -109,7 +115,12 @@ def copy(self): obj.retry_exceptions = self.retry_exceptions return obj - def __call__(self, func, *args, **kwargs): + def __call__( + self, + func: Callable[..., RETRY_RETURN], + *args: Any, + **kwargs: Any, + ) -> RETRY_RETURN: """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/kazoo/security.py b/kazoo/security.py index 68399445..1b383795 100644 --- a/kazoo/security.py +++ b/kazoo/security.py @@ -1,14 +1,19 @@ """Kazoo Security""" + +from __future__ import annotations + from base64 import b64encode -from collections import namedtuple import hashlib +from typing import NamedTuple # Represents a Zookeeper ID and ACL object -Id = namedtuple("Id", "scheme id") +class Id(NamedTuple): + scheme: str + id: str -class ACL(namedtuple("ACL", "perms id")): +class ACL(NamedTuple): """An ACL for a Zookeeper Node An ACL object is created by using an :class:`Id` object along with @@ -17,8 +22,11 @@ class ACL(namedtuple("ACL", "perms id")): the desired scheme, id, and permissions. 
""" + perms: int + id: Id + @property - def acl_list(self): + def acl_list(self) -> list[str]: perms = [] if self.perms & Permissions.ALL == Permissions.ALL: perms.append("ALL") @@ -35,7 +43,7 @@ def acl_list(self): perms.append("ADMIN") return perms - def __repr__(self): + def __repr__(self) -> str: return "ACL(perms=%r, acl_list=%s, id=%r)" % ( self.perms, self.acl_list, @@ -62,7 +70,7 @@ class Permissions(object): READ_ACL_UNSAFE = [ACL(Permissions.READ, ANYONE_ID_UNSAFE)] -def make_digest_acl_credential(username, password): +def make_digest_acl_credential(username: str, password: str) -> str: """Create a SHA1 digest credential. .. note:: @@ -80,15 +88,15 @@ def make_digest_acl_credential(username, password): def make_acl( - scheme, - credential, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + scheme: str, + credential: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Given a scheme and credential, return an :class:`ACL` object appropriate for use with Kazoo. 
@@ -131,15 +139,15 @@ def make_acl( def make_digest_acl( - username, - password, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + username: str, + password: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Create a digest ACL for Zookeeper with the given permissions This method combines :meth:`make_digest_acl_credential` and diff --git a/pyproject.toml b/pyproject.toml index db3890c5..48f3b5fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,10 @@ ignore_missing_imports = false # Disallow dynamic typing disallow_any_unimported = true disallow_any_expr = false -disallow_any_decorated = true -disallow_any_explicit = true +# disallow_any_decorated is disabled because it produces an error for almost +# every decorated function (at least in python3.8) +disallow_any_decorated = false +disallow_any_explicit = false # true disallow_any_generics = true disallow_subclassing_any = true @@ -81,7 +83,7 @@ hide_error_codes = false pretty = true color_output = true error_summary = true -show_absolute_path = true +show_absolute_path = false # Miscellaneous warn_unused_configs = true @@ -90,33 +92,13 @@ verbosity = 0 # FIXME: As type annotations are introduced, please remove the appropriate # ignore_errors flag below. New modules should NOT be added here! +# This is a temporary workaround for the fact that mypy can +# produce different errors in 3.8 and 3.14, and I want to avoid code changes +# as much as possible. 
+disable_error_code = [ 'unused-ignore' ] + [[tool.mypy.overrides]] module = [ - 'kazoo.client', - 'kazoo.exceptions', - 'kazoo.handlers.eventlet', - 'kazoo.handlers.gevent', - 'kazoo.handlers.threading', - 'kazoo.handlers.utils', - 'kazoo.hosts', - 'kazoo.interfaces', - 'kazoo.loggingsupport', - 'kazoo.protocol.connection', - 'kazoo.protocol.paths', - 'kazoo.protocol.serialization', - 'kazoo.protocol.states', - 'kazoo.recipe.barrier', - 'kazoo.recipe.cache', - 'kazoo.recipe.counter', - 'kazoo.recipe.election', - 'kazoo.recipe.lease', - 'kazoo.recipe.lock', - 'kazoo.recipe.partitioner', - 'kazoo.recipe.party', - 'kazoo.recipe.queue', - 'kazoo.recipe.watchers', - 'kazoo.retry', - 'kazoo.security', 'kazoo.testing.common', 'kazoo.testing.harness', 'kazoo.tests.conftest', @@ -146,6 +128,17 @@ module = [ 'kazoo.tests.test_utils', 'kazoo.tests.test_watchers', 'kazoo.tests.util', - 'kazoo.version' ] ignore_errors = true + +[[tool.mypy.overrides]] + module = ["eventlet.*"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["gevent.thread"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["puresasl.*"] + follow_untyped_imports = true diff --git a/setup.cfg b/setup.cfg index 2e771f17..4dfd5816 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,6 +65,7 @@ test = eventlet>=0.17.1 ; implementation_name!='pypy' pyjks pyopenssl + typing-extensions eventlet = eventlet>=0.17.1 @@ -81,6 +82,10 @@ docs = typing = mypy>=0.991 + types-gevent + +other = + typing-extensions alldeps = %(dev)s @@ -88,5 +93,6 @@ alldeps = %(gevent)s %(sasl)s %(docs)s + %(other)s %(typing)s