From ec7dc06bc9be840c34bca60171d25f79f183033e Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Thu, 23 Apr 2026 08:12:18 +0100 Subject: [PATCH 1/8] Add type checking --- .coveragerc | 7 + .flake8 | 14 +- .gitignore | 6 +- ensure-zookeeper-env.sh | 7 +- kazoo/client.py | 472 +++++++++++++++++++++++--------- kazoo/exceptions.py | 11 +- kazoo/handlers/eventlet.py | 60 ++-- kazoo/handlers/gevent.py | 52 ++-- kazoo/handlers/threading.py | 55 ++-- kazoo/handlers/utils.py | 142 ++++++---- kazoo/hosts.py | 11 +- kazoo/interfaces.py | 133 +++++++-- kazoo/protocol/connection.py | 209 ++++++++++---- kazoo/protocol/paths.py | 12 +- kazoo/protocol/serialization.py | 239 ++++++++++------ kazoo/protocol/states.py | 66 +++-- kazoo/recipe/barrier.py | 39 ++- kazoo/recipe/cache.py | 106 ++++--- kazoo/recipe/counter.py | 41 ++- kazoo/recipe/election.py | 21 +- kazoo/recipe/lease.py | 50 ++-- kazoo/recipe/lock.py | 154 ++++++++--- kazoo/recipe/partitioner.py | 79 ++++-- kazoo/recipe/party.py | 56 +++- kazoo/recipe/queue.py | 57 ++-- kazoo/recipe/watchers.py | 112 ++++++-- kazoo/retry.py | 40 +-- kazoo/security.py | 56 ++-- pyproject.toml | 61 +++-- setup.cfg | 1 + 30 files changed, 1623 insertions(+), 746 deletions(-) diff --git a/.coveragerc b/.coveragerc index d84a6fc8b..0c53d00b0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,10 @@ include = omit = kazoo/tests/* kazoo/testing/* + +# Note - this is a copy of the default exclusions from coverage 7.10.1 +[report] +exclude_lines = + #\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER) + ^\s*(((async )?def .*?)?\)(\s*->.*?)?:\s*)?\.\.\.\s*(#|$) + if (typing\.)?TYPE_CHECKING: diff --git a/.flake8 b/.flake8 index ba8a3d67f..839f03220 100644 --- a/.flake8 +++ b/.flake8 @@ -3,12 +3,18 @@ builtins = _ exclude = .git, __pycache__, - .venv/,venv/, + .venv/, + venv*/, .tox/, - build/,dist/,*egg, + build/, + dist/, + *egg, docs/conf.py, - zookeeper/ + zookeeper/, # See black's documentation for E203 max-line-length = 79 -extend-ignore = BLK100,E203 +# I am not sure what version of flake8 hound is using +# but it gives a lot of undefined names for comments (F821) +# and redefinition of unused variables (F811) +extend-ignore = BLK100,E203,F811,F821 diff --git a/.gitignore b/.gitignore index cc73e2674..94c5a4e2e 100644 --- a/.gitignore +++ b/.gitignore @@ -30,10 +30,14 @@ zookeeper/ .project .pydevproject .tox -venv +venv*/ /.settings /.metadata +__pycache__/ !.gitignore !.git-blame-ignore-revs +.vscode/settings.json +.*_cache/ +coverage.xml diff --git a/ensure-zookeeper-env.sh b/ensure-zookeeper-env.sh index ae0272a21..6296f3b04 100755 --- a/ensure-zookeeper-env.sh +++ b/ensure-zookeeper-env.sh @@ -4,10 +4,11 @@ set -e HERE=`pwd` ZOO_BASE_DIR="$HERE/zookeeper" -ZOOKEEPER_VERSION=${ZOOKEEPER_VERSION:-3.4.14} +export ZOOKEEPER_VERSION=${ZOOKEEPER_VERSION:-3.6.4} ZOOKEEPER_PATH="$ZOO_BASE_DIR/$ZOOKEEPER_VERSION" -ZOOKEEPER_PREFIX=${ZOOKEEPER_PREFIX} -ZOOKEEPER_SUFFIX=${ZOOKEEPER_SUFFIX} +ZOOKEEPER_PREFIX=${ZOOKEEPER_PREFIX:-apache} +ZOOKEEPER_SUFFIX=${ZOOKEEPER_SUFFIX:--bin} +ZOOKEEPER_LIB=${ZOOKEEPER_LIB:-lib} ZOO_MIRROR_URL="https://archive.apache.org/dist" diff --git a/kazoo/client.py b/kazoo/client.py index 3029d1c5f..ea9f5856e 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -1,4 +1,7 @@ """Kazoo Zookeeper Client""" + +from __future__ import annotations + from collections import defaultdict, deque from functools import partial import inspect @@ -6,6 +9,17 @@ from os.path import split import re import warnings +from typing import ( + Any, + Callable, + Literal, + Optional, + Sequence, + Set, + Union, + overload, + TYPE_CHECKING, +) from kazoo.exceptions import ( AuthFailedError, @@ -63,6 +77,10 @@ from kazoo.recipe.queue import Queue, LockingQueue from kazoo.recipe.watchers import ChildrenWatch, DataWatch +if TYPE_CHECKING: + from kazoo.interfaces import IAsyncResult, IHandler + from kazoo.protocol.states import ZnodeStat + CLOSED_STATES = ( KeeperState.EXPIRED_SESSION, @@ -88,6 +106,12 @@ retry_max_delay="max_delay", ) +# Signature for functions called by add_listener +ListenerFunc = Callable[[KazooState], Optional[bool]] + +# Signatures for get, get_children and exists watches +WatchFunc = Callable[[WatchedEvent], Optional[bool]] + class KazooClient(object): """An Apache Zookeeper Python client supporting alternate callback @@ -102,27 +126,27 @@ class KazooClient(object): def __init__( self, - hosts="127.0.0.1:2181", - timeout=10.0, - client_id=None, - handler=None, - default_acl=None, - auth_data=None, - sasl_options=None, - read_only=None, - randomize_hosts=True, - connection_retry=None, - command_retry=None, - logger=None, - keyfile=None, - keyfile_password=None, - certfile=None, - ca=None, - use_ssl=False, - verify_certs=True, - check_hostname=False, - **kwargs, - ): + hosts: Union[str, list[str]] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: Optional[tuple] = None, + handler: Optional[IHandler] = None, + default_acl: Optional[Sequence[ACL]] = None, + auth_data: Optional[set] = None, + sasl_options: Optional[dict] = None, + read_only: Optional[bool] = None, + randomize_hosts: bool = True, + connection_retry: Optional[Union[KazooRetry, dict]] = None, + command_retry: Optional[Union[KazooRetry, dict]] = None, + logger: Optional[logging.Logger] = None, + keyfile: Optional[str] = None, + keyfile_password: Optional[str] = None, + certfile: Optional[str] = None, + ca: Optional[str] = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Any, + ) -> None: """Create a :class:`KazooClient` instance. All time arguments are in seconds. @@ -234,8 +258,15 @@ def __init__( self.auth_data = auth_data if auth_data else set([]) self.default_acl = default_acl self.randomize_hosts = randomize_hosts - self.hosts = None - self.chroot = None + # Note: hosts and chroot are set by set_hosts, which also checks for + # chroot changes at runtime, so we initialize them to None here to + # avoid confusion with the empty string that set_hosts would set them + # to. This is massively hacky as set_hosts is only called from here + # anyway, but I want to make this change minimally invasive. + # we should really do self.hosts, self.chroot = self.set_hosts(hosts) + # and have set_hosts return the hosts and chroot + self.hosts: list[tuple[str, int]] = None # type: ignore[assignment] + self.chroot: str = None # type: ignore[assignment] self.set_hosts(hosts) self.use_ssl = use_ssl @@ -247,11 +278,15 @@ def __init__( self.ca = ca # Curator like simplified state tracking, and listeners for # state transitions - self._state = KeeperState.CLOSED - self.state = KazooState.LOST - self.state_listeners = set() - self._child_watchers = defaultdict(set) - self._data_watchers = defaultdict(set) + self._state: KeeperState = KeeperState.CLOSED + self.state: KazooState = KazooState.LOST + self.state_listeners: set[ListenerFunc] = set() + self._child_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) + self._data_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) self._reset() self.read_only = read_only @@ -272,7 +307,14 @@ def __init__( self._stopped.set() self._writer_stopped.set() - self.retry = self._conn_retry = None + # This is kind of gross but we need to set these to something so that + # the type checker will understand that they are set by the time they + # are used and that they have the right type. + # We would do better to use a few variables/functions instead of + # overloading self.retry but this is a bit less invasive to the code + # and the type checker can understand it with a few hacks + self.retry: KazooRetry = None # type: ignore[assignment] + self._conn_retry: KazooRetry = None # type: ignore[assignment] if type(connection_retry) is dict: self._conn_retry = KazooRetry(**connection_retry) @@ -299,7 +341,11 @@ def __init__( ) if self.retry is None or self._conn_retry is None: - old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS) + # Note: because of the hacks at line 280, mypy thinks this is + # unreachable + old_retry_keys = dict( # type: ignore[unreachable] + _RETRY_COMPAT_DEFAULTS + ) for key in old_retry_keys: try: old_retry_keys[key] = kwargs.pop(key) @@ -320,11 +366,13 @@ def __init__( if self._conn_retry is None: self._conn_retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) if self.retry is None: self.retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) # Managing legacy SASL options @@ -372,10 +420,18 @@ def __init__( # to avoid shared retry counts self._retry = self.retry - def _retry(*args, **kwargs): + # also this should be called with a func that returns nothing as the + # 1st argument. + def _retry(*args: Any, **kwargs: Any) -> Any: return self._retry.copy()(*args, **kwargs) - self.retry = _retry + # (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", + # variable has type "KazooRetry") so basically self.retry needs to be + # set to that and then the type checker will understand that + # self.retry.copy() is a valid call. This is just a mess and needs the + # code rearranging to be more mypy friendly but this is the least + # invasive way to do it for now + self.retry = _retry # type: ignore[assignment] self.Barrier = partial(Barrier, self) self.Counter = partial(Counter, self) @@ -402,18 +458,18 @@ def _retry(*args, **kwargs): % (kwargs.keys(),) ) - def _reset(self): + def _reset(self) -> None: """Resets a variety of client states for a new connection.""" - self._queue = deque() - self._pending = deque() + self._queue: deque = deque() + self._pending: deque = deque() self._reset_watchers() self._reset_session() self.last_zxid = 0 self._protocol_version = None - def _reset_watchers(self): - watchers = [] + def _reset_watchers(self) -> None: + watchers: list[WatchFunc] = [] for child_watchers in self._child_watchers.values(): watchers.extend(child_watchers) @@ -427,12 +483,12 @@ def _reset_watchers(self): for watch in watchers: self.handler.dispatch_callback(Callback("watch", watch, (ev,))) - def _reset_session(self): + def _reset_session(self) -> None: self._session_id = None self._session_passwd = b"\x00" * 16 @property - def client_state(self): + def client_state(self) -> KeeperState: """Returns the last Zookeeper client state This is the non-simplified state information and is generally @@ -442,7 +498,7 @@ def client_state(self): return self._state @property - def client_id(self): + def client_id(self) -> Optional[tuple]: """Returns the client id for this Zookeeper session if connected. @@ -455,12 +511,16 @@ def client_id(self): return None @property - def connected(self): + def connected(self) -> bool: """Returns whether the Zookeeper connection has been established.""" return self._live.is_set() - def set_hosts(self, hosts, randomize_hosts=None): + def set_hosts( + self, + hosts: Union[str, list[str]], + randomize_hosts: Optional[bool] = None, + ) -> None: """sets the list of hosts used by this client. This function accepts the same format hosts parameter as the init @@ -504,7 +564,7 @@ def set_hosts(self, hosts, randomize_hosts=None): self.chroot = new_chroot - def add_listener(self, listener): + def add_listener(self, listener: ListenerFunc) -> None: """Add a function to be called for connection state changes. This function will be called with a @@ -519,15 +579,20 @@ def add_listener(self, listener): should be used so that the listener can return immediately. """ - if not (listener and callable(listener)): + # This check should be unnecessary but protects against people who are + # not using type checkers and accidentally passing in something that + # isn't callable. It should be removed. + if not ( + listener and callable(listener) # type: ignore[truthy-function] + ): raise ConfigurationError("listener must be callable") self.state_listeners.add(listener) - def remove_listener(self, listener): + def remove_listener(self, listener: ListenerFunc) -> None: """Remove a listener function""" self.state_listeners.discard(listener) - def _make_state_change(self, state): + def _make_state_change(self, state: KazooState) -> None: # skip if state is current if self.state == state: return @@ -544,7 +609,7 @@ def _make_state_change(self, state): except Exception: self.logger.exception("Error in connection state listener") - def _session_callback(self, state): + def _session_callback(self, state: KeeperState) -> None: if state == self._state: return @@ -581,9 +646,10 @@ def _session_callback(self, state): self._make_state_change(KazooState.SUSPENDED) self._reset_watchers() - def _notify_pending(self, state): + def _notify_pending(self, state: KeeperState) -> None: """Used to clear a pending response queue and request queue during connection drops.""" + exc: KazooException if state == KeeperState.AUTH_FAILED: exc = AuthFailedError() elif state == KeeperState.EXPIRED_SESSION: @@ -607,7 +673,7 @@ def _notify_pending(self, state): except IndexError: break - def _safe_close(self): + def _safe_close(self) -> None: self.handler.stop() timeout = self._session_timeout // 1000 if timeout < 10: @@ -618,7 +684,9 @@ def _safe_close(self): "and wouldn't close after %s seconds" % timeout ) - def _call(self, request, async_object): + def _call( + self, request: object, async_object: IAsyncResult + ) -> Optional[bool]: """Ensure the client is in CONNECTED or SUSPENDED state and put the request in the queue if it is. @@ -647,14 +715,16 @@ def _call(self, request, async_object): async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None try: write_sock.send(b"\0") except: # NOQA async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None - def start(self, timeout=15): + def start(self, timeout: float = 15.0) -> None: """Initiate connection to ZK. :param timeout: Time in seconds to wait for connection to @@ -678,7 +748,7 @@ def start(self, timeout=15): "should be created before normal use." ) - def start_async(self): + def start_async(self) -> Any: """Asynchronously initiate connection to ZK. :returns: An event object that can be checked to see if the @@ -705,7 +775,7 @@ def start_async(self): self._connection.start() return self._live - def stop(self): + def stop(self) -> None: """Gracefully stop this Zookeeper session. This method can be called while a reconnection attempt is in @@ -723,16 +793,20 @@ def stop(self): self._stopped.set() self._queue.append((CloseInstance, None)) try: - self._connection._write_sock.send(b"\0") + # This assert should never fail since the connection should + # have been started but I'm not sure how to persaude mypy of that + self._connection._write_sock.send( # type: ignore[union-attr] + b"\0" + ) finally: self._safe_close() - def restart(self): + def restart(self) -> None: """Stop and restart the Zookeeper session.""" self.stop() self.start() - def close(self): + def close(self) -> None: """Free any resources held by the client. This method should be called on a stopped client before it is @@ -742,7 +816,7 @@ def close(self): """ self._connection.close() - def command(self, cmd=b"ruok"): + def command(self, cmd: bytes = b"ruok") -> str: """Sent a management command to the current ZK server. Examples are `ruok`, `envi` or `stat`. @@ -761,8 +835,18 @@ def command(self, cmd=b"ruok"): if not self._live.is_set(): raise ConnectionLoss("No connection to server") - peer = self._connection._socket.getpeername()[:2] - peer_host = self._connection._socket.getpeername()[1] + # Need a way of persauding mypy that the connection is live and thus + # the socket is not None + peer = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + :2 + ] + ) + peer_host = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + 1 + ] + ) sock = self.handler.create_connection( peer, hostname=peer_host, @@ -780,7 +864,7 @@ def command(self, cmd=b"ruok"): sock.close() return result.decode("utf-8", "replace") - def server_version(self, retries=3): + def server_version(self, retries: int = 3) -> tuple: """Get the version of the currently connected ZK server. :returns: The server version, for example (3, 4, 3). @@ -790,7 +874,7 @@ def server_version(self, retries=3): """ - def _try_fetch(): + def _try_fetch() -> Optional[tuple[int, ...]]: data = self.command(b"envi") data_parsed = {} for line in data.splitlines(): @@ -804,13 +888,19 @@ def _try_fetch(): if k: data_parsed[k] = v version = data_parsed.get(ENVI_VERSION_KEY, "") - version_digits = ENVI_VERSION.match(version).group(1) + # a) if you get an unexpected answer, you'll crash + # b) not changing the code, so just ignoring the type error + version_digits = ENVI_VERSION.match( + version + ).group( # type: ignore[union-attr] + 1 + ) try: return tuple([int(d) for d in version_digits.split(".")]) except ValueError: return None - def _is_valid(version): + def _is_valid(version: Optional[tuple[int, ...]]) -> bool: # All zookeeper versions should have at least major.minor # version numbers; if we get one that doesn't it is likely not # correct and was truncated... @@ -818,21 +908,28 @@ def _is_valid(version): return True return False + # A better way of doing this would be to put the initial _try_fetch in + # the loop and inline _is_valid but I want to minimise code changes + # Try 1 + retries amount of times to get a version that we know # will likely be acceptable... version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + # and the next 2 suppress should include return-value + # but hound is broken + return version # type: ignore for _i in range(0, retries): version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + return version # type: ignore raise KazooException( "Unable to fetch useable server" " version after trying %s times" % (1 + max(0, retries)) ) - def add_auth(self, scheme, credential): + def add_auth(self, scheme: str, credential: str) -> bool: """Send credentials to server. :param scheme: authentication scheme (default supported: @@ -849,7 +946,7 @@ def add_auth(self, scheme, credential): """ return self.add_auth_async(scheme, credential).get() - def add_auth_async(self, scheme, credential): + def add_auth_async(self, scheme: str, credential: str) -> IAsyncResult: """Asynchronously send credentials to server. Takes the same arguments as :meth:`add_auth`. @@ -868,7 +965,7 @@ def add_auth_async(self, scheme, credential): self._call(Auth(0, scheme, credential), async_result) return async_result - def unchroot(self, path): + def unchroot(self, path: str) -> str: """Strip the chroot if applicable from the path.""" if not self.chroot: return path @@ -879,7 +976,7 @@ def unchroot(self, path): else: return path - def sync_async(self, path): + def sync_async(self, path: str) -> IAsyncResult: """Asynchronous sync. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -888,10 +985,10 @@ def sync_async(self, path): async_result = self.handler.async_result() @wrap(async_result) - def _sync_completion(result): + def _sync_completion(result: IAsyncResult) -> str: return self.unchroot(result.get()) - def _do_sync(): + def _do_sync() -> None: result = self.handler.async_result() self._call(Sync(_prefix_root(self.chroot, path)), result) result.rawlink(_sync_completion) @@ -899,7 +996,7 @@ def _do_sync(): _do_sync() return async_result - def sync(self, path): + def sync(self, path: str) -> str: """Sync, blocks until response is acknowledged. Flushes channel between process and leader. @@ -915,16 +1012,42 @@ def sync(self, path): """ return self.sync_async(path).get() + @overload + def create( + self, + path: str, + value: bytes = b"", + acl: Optional[Sequence[ACL]] = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[False] = False, + ) -> str: + ... + + @overload + def create( + self, + path: str, + value: bytes = b"", + acl: Optional[Sequence[ACL]] = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[True] = True, + ) -> tuple[str, ZnodeStat]: + ... + def create( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes = b"", + acl: Optional[Sequence[ACL]] = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> Union[str, tuple[str, ZnodeStat]]: """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1015,14 +1138,14 @@ def create( def create_async( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes = b"", + acl: Optional[Sequence[ACL]] = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously create a ZNode. Takes the same arguments as :meth:`create`. @@ -1066,11 +1189,16 @@ def create_async( async_result = self.handler.async_result() @capture_exceptions(async_result) - def do_create(): + def do_create() -> None: result = self._create_async_inner( path, value, - acl, + # The way acl is constructed ends up confusing mypy, which + # thinks that acl can be None here, even though the code + # above ensures that if acl is None, it gets set to + # OPEN_ACL_UNSAFE, so we ignore the type error here. + # behaves differently in python3.8 and python3.14, sigh. + acl, # type: ignore[arg-type] flags, trailing=sequence, include_data=include_data, @@ -1078,12 +1206,14 @@ def do_create(): result.rawlink(create_completion) @capture_exceptions(async_result) - def retry_completion(result): + def retry_completion(result: IAsyncResult) -> None: result.get() do_create() @wrap(async_result) - def create_completion(result): + def create_completion( + result: IAsyncResult, + ) -> Optional[Union[str, tuple[str, ZnodeStat]]]: try: if include_data: new_path, stat = result.get() @@ -1098,18 +1228,22 @@ def create_completion(result): else: parent, _ = split(path) self.ensure_path_async(parent, acl).rawlink(retry_completion) + return None do_create() return async_result def _create_async_inner( - self, path, value, acl, flags, trailing=False, include_data=False - ): + self, + path: str, + value: bytes, + acl: Sequence[ACL], + flags: int, + trailing: bool = False, + include_data: bool = False, + ) -> IAsyncResult: async_result = self.handler.async_result() - if include_data: - opcode = Create2 - else: - opcode = Create + opcode = Create2 if include_data else Create call_result = self._call( opcode( @@ -1126,10 +1260,15 @@ def _create_async_inner( # exception upwards to the do_create function in # KazooClient.create so that it gets set on the correct # async_result object - raise async_result.exception + # Note: Do we actually need call_result? It seems like we could + # just check the state of the exception, and avoid the typing + # stuff. + raise async_result.exception # type: ignore[misc] return async_result - def ensure_path(self, path, acl=None): + def ensure_path( + self, path: str, acl: Optional[Sequence[ACL]] = None + ) -> bool: """Recursively create a path if it doesn't exist. :param path: Path of node. @@ -1138,7 +1277,9 @@ def ensure_path(self, path, acl=None): """ return self.ensure_path_async(path, acl).get() - def ensure_path_async(self, path, acl=None): + def ensure_path_async( + self, path: str, acl: Optional[Sequence[ACL]] = None + ) -> IAsyncResult: """Recursively create a path asynchronously if it doesn't exist. Takes the same arguments as :meth:`ensure_path`. @@ -1151,19 +1292,21 @@ def ensure_path_async(self, path, acl=None): async_result = self.handler.async_result() @wrap(async_result) - def create_completion(result): + def create_completion(result: Any) -> bool: try: return result.get() except NodeExistsError: return True @capture_exceptions(async_result) - def prepare_completion(next_path, result): + def prepare_completion(next_path: str, result: IAsyncResult) -> None: result.get() self.create_async(next_path, acl=acl).rawlink(create_completion) @wrap(async_result) - def exists_completion(path, result): + def exists_completion( + path: str, result: IAsyncResult + ) -> Optional[Literal[True]]: if result.get(): return True parent, node = split(path) @@ -1173,12 +1316,15 @@ def exists_completion(path, result): ) else: self.create_async(path, acl=acl).rawlink(create_completion) + return None self.exists_async(path).rawlink(partial(exists_completion, path)) return async_result - def exists(self, path, watch=None): + def exists( + self, path: str, watch: Optional[WatchFunc] = None + ) -> Optional[ZnodeStat]: """Check if a node exists. If a watch is provided, it will be left on the node with the @@ -1200,7 +1346,9 @@ def exists(self, path, watch=None): """ return self.exists_async(path, watch=watch).get() - def exists_async(self, path, watch=None): + def exists_async( + self, path: str, watch: Optional[WatchFunc] = None + ) -> IAsyncResult: """Asynchronously check if a node exists. Takes the same arguments as :meth:`exists`. @@ -1218,7 +1366,9 @@ def exists_async(self, path, watch=None): ) return async_result - def get(self, path, watch=None): + def get( + self, path: str, watch: Optional[WatchFunc] = None + ) -> tuple[bytes, ZnodeStat]: """Get the value of a node. If a watch is provided, it will be left on the node with the @@ -1243,7 +1393,9 @@ def get(self, path, watch=None): """ return self.get_async(path, watch=watch).get() - def get_async(self, path, watch=None): + def get_async( + self, path: str, watch: Optional[WatchFunc] = None + ) -> IAsyncResult: """Asynchronously get the value of a node. Takes the same arguments as :meth:`get`. @@ -1261,7 +1413,12 @@ def get_async(self, path, watch=None): ) return async_result - def get_children(self, path, watch=None, include_data=False): + def get_children( + self, + path: str, + watch: Optional[WatchFunc] = None, + include_data: bool = False, + ) -> list[str]: """Get a list of child nodes of a path. If a watch is provided it will be left on the node with the @@ -1299,7 +1456,12 @@ def get_children(self, path, watch=None, include_data=False): path, watch=watch, include_data=include_data ).get() - def get_children_async(self, path, watch=None, include_data=False): + def get_children_async( + self, + path: str, + watch: Optional[WatchFunc] = None, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously get a list of child nodes of a path. Takes the same arguments as :meth:`get_children`. @@ -1314,6 +1476,7 @@ def get_children_async(self, path, watch=None, include_data=False): raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() + req: Union[GetChildren, GetChildren2] if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1321,7 +1484,7 @@ def get_children_async(self, path, watch=None, include_data=False): self._call(req, async_result) return async_result - def get_acls(self, path): + def get_acls(self, path: str) -> tuple[list[ACL], ZnodeStat]: """Return the ACL and stat of the node of the given path. :param path: Path of the node. @@ -1341,7 +1504,7 @@ def get_acls(self, path): """ return self.get_acls_async(path).get() - def get_acls_async(self, path): + def get_acls_async(self, path: str) -> IAsyncResult: """Return the ACL and stat of the node of the given path. Takes the same arguments as :meth:`get_acls`. @@ -1355,7 +1518,9 @@ def get_acls_async(self, path): self._call(GetACL(_prefix_root(self.chroot, path)), async_result) return async_result - def set_acls(self, path, acls, version=-1): + def set_acls( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> ZnodeStat: """Set the ACL for the node of the given path. Set the ACL for the node of the given path if such a node @@ -1384,7 +1549,9 @@ def set_acls(self, path, acls, version=-1): """ return self.set_acls_async(path, acls, version).get() - def set_acls_async(self, path, acls, version=-1): + def set_acls_async( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> IAsyncResult: """Set the ACL for the node of the given path. Takes the same arguments as :meth:`set_acls`. @@ -1407,7 +1574,9 @@ def set_acls_async(self, path, acls, version=-1): ) return async_result - def set(self, path, value, version=-1): + def set( + self, path: str, value: Optional[bytes], version: int = -1 + ) -> ZnodeStat: """Set the value of a node. If the version of the node being updated is newer than the @@ -1442,7 +1611,9 @@ def set(self, path, value, version=-1): """ return self.set_async(path, value, version).get() - def set_async(self, path, value, version=-1): + def set_async( + self, path: str, value: Optional[bytes], version: int = -1 + ) -> IAsyncResult: """Set the value of a node. Takes the same arguments as :meth:`set`. @@ -1463,7 +1634,7 @@ def set_async(self, path, value, version=-1): ) return async_result - def transaction(self): + def transaction(self) -> TransactionRequest: """Create and return a :class:`TransactionRequest` object Creates a :class:`TransactionRequest` object. A Transaction can @@ -1480,7 +1651,12 @@ def transaction(self): """ return TransactionRequest(self) - def delete(self, path, version=-1, recursive=False): + # This should not return anything. No return value is documented, and the + # two called functions return different things. AFAICT. But for now I + # want to minimise code changes + def delete( + self, path: str, version: int = -1, recursive: bool = False + ) -> Any: """Delete a node. The call will succeed if such a node exists, and the given @@ -1518,7 +1694,7 @@ def delete(self, path, version=-1, recursive=False): else: return self.delete_async(path, version).get() - def delete_async(self, path, version=-1): + def delete_async(self, path: str, version: int = -1) -> IAsyncResult: """Asynchronously delete a node. Takes the same arguments as :meth:`delete`, with the exception of `recursive`. @@ -1535,7 +1711,7 @@ def delete_async(self, path, version=-1): ) return async_result - def _delete_recursive(self, path): + def _delete_recursive(self, path: str) -> Optional[Literal[True]]: try: children = self.get_children(path) except NoNodeError: @@ -1553,8 +1729,15 @@ def _delete_recursive(self, path): self.delete(path) except NoNodeError: # pragma: nocover pass + return None - def reconfig(self, joining, leaving, new_members, from_config=-1): + def reconfig( + self, + joining: Optional[str], + leaving: Optional[str], + new_members: Optional[str], + from_config: int = -1, + ) -> tuple[bytes, ZnodeStat]: """Reconfig a cluster. This call will succeed if the cluster was reconfigured accordingly. @@ -1627,7 +1810,13 @@ def reconfig(self, joining, leaving, new_members, from_config=-1): ) return result.get() - def reconfig_async(self, joining, leaving, new_members, from_config): + def reconfig_async( + self, + joining: Optional[str], + leaving: Optional[str], + new_members: Optional[str], + from_config: int, + ) -> IAsyncResult: """Asynchronously reconfig a cluster. Takes the same arguments as :meth:`reconfig`. @@ -1674,14 +1863,19 @@ class TransactionRequest(object): """ - def __init__(self, client): + def __init__(self, client: KazooClient): self.client = client - self.operations = [] + self.operations: list[Any] = [] self.committed = False def create( - self, path, value=b"", acl=None, ephemeral=False, sequence=False - ): + self, + path: str, + value: bytes = b"", + acl: Optional[Sequence[ACL]] = None, + ephemeral: bool = False, + sequence: bool = False, + ) -> None: """Add a create ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.create`, with the exception of `makepath`. @@ -1718,7 +1912,7 @@ def create( None, ) - def delete(self, path, version=-1): + def delete(self, path: str, version: int = -1) -> None: """Add a delete ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.delete`, with the exception of `recursive`. @@ -1730,7 +1924,7 @@ def delete(self, path, version=-1): raise TypeError("Invalid type for 'version' (int expected)") self._add(Delete(_prefix_root(self.client.chroot, path), version)) - def set_data(self, path, value, version=-1): + def set_data(self, path: str, value: bytes, version: int = -1) -> None: """Add a set ZNode value to the transaction. Takes the same arguments as :meth:`KazooClient.set`. @@ -1745,7 +1939,7 @@ def set_data(self, path, value, version=-1): SetData(_prefix_root(self.client.chroot, path), value, version) ) - def check(self, path, version): + def check(self, path: str, version: int) -> None: """Add a Check Version to the transaction. This command will fail and abort a transaction if the path @@ -1760,7 +1954,7 @@ def check(self, path, version): CheckVersion(_prefix_root(self.client.chroot, path), version) ) - def commit_async(self): + def commit_async(self) -> IAsyncResult: """Commit the transaction asynchronously. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -1772,7 +1966,7 @@ def commit_async(self): self.client._call(Transaction(self.operations), async_object) return async_object - def commit(self): + def commit(self) -> list[Any]: """Commit the transaction. :returns: A list of the results for each operation in the @@ -1781,19 +1975,23 @@ def commit(self): """ return self.commit_async().get() - def __enter__(self): + def __enter__(self) -> TransactionRequest: return self - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: """Commit and cleanup accumulated transaction data.""" if not exc_type: self.commit() - def _check_tx_state(self): + def _check_tx_state(self) -> None: if self.committed: raise ValueError("Transaction already committed") - def _add(self, request, post_processor=None): + def _add( + self, + request: Any, + post_processor: Optional[Callable[[Any], Any]] = None, + ) -> None: self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) self.operations.append(request) diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index b24c697cb..2bfa7d66e 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -1,6 +1,9 @@ """Kazoo Exceptions""" +from __future__ import annotations + from collections import defaultdict +from typing import Any, Callable, Type class KazooException(Exception): @@ -51,15 +54,15 @@ class SASLException(KazooException): """ -def _invalid_error_code(): +def _invalid_error_code() -> Any: raise RuntimeError("Invalid error code") -EXCEPTIONS = defaultdict(_invalid_error_code) +EXCEPTIONS: defaultdict = defaultdict(_invalid_error_code) -def _zookeeper_exception(code): - def decorator(klass): +def _zookeeper_exception(code: int) -> Callable[[Type[Any]], Type[Any]]: + def decorator(klass: Type[Any]) -> Type[Any]: EXCEPTIONS[code] = klass klass.code = code return klass diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 8869cc570..028a18bfd 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -1,10 +1,14 @@ """A eventlet based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import contextlib import logging +from typing import Any, Generator, TYPE_CHECKING + import eventlet from eventlet.green import socket as green_socket from eventlet.green import time as green_time @@ -15,6 +19,10 @@ from kazoo.handlers import utils from kazoo.handlers.utils import selector_select +if TYPE_CHECKING: + from kazoo.interfaces import Socket + + LOG = logging.getLogger(__name__) # sentinel objects @@ -22,7 +30,7 @@ @contextlib.contextmanager -def _yield_before_after(): +def _yield_before_after() -> Generator[None, None, None]: # Yield to any other co-routines... # # See: http://eventlet.net/doc/modules/greenthread.html @@ -42,7 +50,7 @@ class TimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: Any): super(AsyncResult, self).__init__( handler, green_threading.Condition, TimeoutError ) @@ -81,24 +89,26 @@ class SequentialEventletHandler(object): queue_impl = green_queue.LightQueue queue_empty = green_queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialEventletHandler` instance""" self.callback_queue = self.queue_impl() self.completion_queue = self.queue_impl() - self._workers = [] + self._workers: list[ + tuple[eventlet.GreenThread, green_queue.LightQueue] + ] = [] self._started = False @staticmethod - def sleep_func(wait): + def sleep_func(wait: float) -> None: green_time.sleep(wait) @property - def running(self): + def running(self) -> bool: return self._started timeout_exception = TimeoutError - def _process_completion_queue(self): + def _process_completion_queue(self) -> None: while True: cb = self.completion_queue.get() if cb is _STOP: @@ -114,7 +124,7 @@ def _process_completion_queue(self): finally: del cb # release before possible idle - def _process_callback_queue(self): + def _process_callback_queue(self) -> None: while True: cb = self.callback_queue.get() if cb is _STOP: @@ -130,7 +140,7 @@ def _process_callback_queue(self): finally: del cb # release before possible idle - def start(self): + def start(self) -> None: if not self._started: # Spawn our worker threads, we have # - A callback worker for watch events to be called @@ -142,7 +152,7 @@ def start(self): self._started = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: while self._workers: w, q = self._workers.pop() q.put(_STOP) @@ -150,38 +160,46 @@ def stop(self): self._started = False atexit.unregister(self.stop) - def socket(self, *args, **kwargs): + def socket(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_socket(green_socket) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(green_socket) - def event_object(self): + def event_object(self) -> green_threading.Event: return green_threading.Event() - def lock_object(self): + def lock_object(self) -> green_threading.Lock: return green_threading.Lock() - def rlock_object(self): + def rlock_object(self) -> green_threading.RLock: return green_threading.RLock() - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[list[int], list[int], list[int]]: with _yield_before_after(): + # Following appears to be a bug in mypy (see + # https://github.com/python/mypy/issues/6799) return selector_select( - *args, selectors_module=green_selectors, **kwargs + *args, + selectors_module=green_selectors, # type: ignore[misc] + **kwargs, ) - def async_result(self): + def async_result(self) -> AsyncResult: return AsyncResult(self) - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> green_threading.Thread: t = green_threading.Thread(target=func, args=args, kwargs=kwargs) t.daemon = True t.start() return t - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Any) -> None: self.callback_queue.put(lambda: callback.func(*callback.args)) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index f36389aac..ebc04b2e5 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -1,9 +1,13 @@ """A gevent based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import logging +from typing import Any, TYPE_CHECKING + import gevent from gevent import socket import gevent.event @@ -18,6 +22,10 @@ from kazoo.handlers import utils +if TYPE_CHECKING: + from gevent import Greenlet + from kazoo.interfaces import Socket + _using_libevent = gevent.__version__.startswith("0.") log = logging.getLogger(__name__) @@ -53,24 +61,24 @@ class SequentialGeventHandler(object): queue_empty = gevent.queue.Empty sleep_func = staticmethod(gevent.sleep) - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialGeventHandler` instance""" - self.callback_queue = self.queue_impl() + self.callback_queue: gevent.queue.Queue = self.queue_impl() self._running = False self._async = None self._state_change = Semaphore() - self._workers = [] + self._workers: list[Greenlet] = [] @property - def running(self): + def running(self) -> bool: return self._running class timeout_exception(gevent.Timeout): - def __init__(self, msg): + def __init__(self, msg: Any): gevent.Timeout.__init__(self, exception=msg) - def _create_greenlet_worker(self, queue): - def greenlet_worker(): + def _create_greenlet_worker(self, queue: Any) -> gevent.Greenlet: + def greenlet_worker() -> None: while True: try: func = queue.get() @@ -88,7 +96,7 @@ def greenlet_worker(): return gevent.spawn(greenlet_worker) - def start(self): + def start(self) -> None: """Start the greenlet workers.""" with self._state_change: if self._running: @@ -103,7 +111,7 @@ def start(self): self._workers.append(w) atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the greenlet workers and empty all queues.""" with self._state_change: if not self._running: @@ -123,33 +131,37 @@ def stop(self): atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any) -> tuple: return selector_select( - *args, selectors_module=gevent.selectors, **kwargs + # Likely a bug in mypy (see + # https://github.com/python/mypy/issues/6799) + *args, + selectors_module=gevent.selectors, + **kwargs, # type: ignore[misc] ) - def socket(self, *args, **kwargs): + def socket(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> gevent.event.Event: """Create an appropriate Event object""" return gevent.event.Event() - def lock_object(self): + def lock_object(self) -> Any: """Create an appropriate Lock object""" return gevent.thread.allocate_lock() - def rlock_object(self): + def rlock_object(self) -> RLock: """Create an appropriate RLock object""" return RLock() - def async_result(self): + def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance The :class:`AsyncResult` instance will have its completion @@ -160,11 +172,11 @@ def async_result(self): """ return AsyncResult() - def spawn(self, func, *args, **kwargs): + def spawn(self, func: Any, *args: Any, **kwargs: Any) -> gevent.Greenlet: """Spawn a function to run asynchronously""" return gevent.spawn(func, *args, **kwargs) - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Any) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index b9acd8756..f112022f3 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -10,6 +10,8 @@ :class:`~kazoo.handlers.gevent.SequentialGeventHandler` instead. """ + +from __future__ import annotations from __future__ import absolute_import import atexit @@ -19,9 +21,14 @@ import threading import time +from typing import Any, TYPE_CHECKING + from kazoo.handlers import utils from kazoo.handlers.utils import selector_select +from kazoo.interfaces import IHandler +if TYPE_CHECKING: + from kazoo.interfaces import Socket, SpawnedFunc # sentinel objects _STOP = object() @@ -29,7 +36,7 @@ log = logging.getLogger(__name__) -def _to_fileno(obj): +def _to_fileno(obj: Any) -> int: if isinstance(obj, int): fd = int(obj) elif hasattr(obj, "fileno"): @@ -55,13 +62,13 @@ class KazooTimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: Any) -> None: super(AsyncResult, self).__init__( handler, threading.Condition, KazooTimeoutError ) -class SequentialThreadingHandler(object): +class SequentialThreadingHandler(IHandler): """Threading handler for sequentially executing callbacks. This handler executes callbacks in a sequential manner. A queue is @@ -96,20 +103,22 @@ class SequentialThreadingHandler(object): queue_impl = queue.Queue queue_empty = queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialThreadingHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() + self.callback_queue: queue.Queue = self.queue_impl() + self.completion_queue: queue.Queue = self.queue_impl() self._running = False self._state_change = threading.Lock() - self._workers = [] + self._workers: list[threading.Thread] = [] @property - def running(self): + def running(self) -> bool: return self._running - def _create_thread_worker(self, work_queue): - def _thread_worker(): # pragma: nocover + def _create_thread_worker( + self, work_queue: queue.Queue + ) -> threading.Thread: + def _thread_worker() -> None: # pragma: nocover while True: try: func = work_queue.get() @@ -128,7 +137,7 @@ def _thread_worker(): # pragma: nocover t = self.spawn(_thread_worker) return t - def start(self): + def start(self) -> None: """Start the worker threads.""" with self._state_change: if self._running: @@ -143,7 +152,7 @@ def start(self): self._running = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the worker threads and empty all queues.""" with self._state_change: if not self._running: @@ -164,41 +173,43 @@ def stop(self): self.completion_queue = self.queue_impl() atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any) -> tuple: return selector_select(*args, **kwargs) - def socket(self): + def socket(self) -> Socket: return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> threading.Event: """Create an appropriate Event object""" return threading.Event() - def lock_object(self): + def lock_object(self) -> threading.Lock: """Create a lock object""" return threading.Lock() - def rlock_object(self): + def rlock_object(self) -> threading.RLock: """Create an appropriate RLock object""" return threading.RLock() - def async_result(self): + def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance""" return AsyncResult(self) - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> threading.Thread: t = threading.Thread(target=func, args=args, kwargs=kwargs) t.daemon = True t.start() return t - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Any) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 206806f6a..687730273 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -1,5 +1,7 @@ """Kazoo handler helpers""" +from __future__ import annotations + from collections import defaultdict import errno import functools @@ -8,6 +10,13 @@ import ssl import socket import time +from types import ModuleType +from typing import Any, Callable, Optional, Union, TYPE_CHECKING + +from kazoo.interfaces import IAsyncResult + +if TYPE_CHECKING: + from kazoo.interfaces import Socket HAS_FNCTL = True try: @@ -15,36 +24,51 @@ except ImportError: # pragma: nocover HAS_FNCTL = False + # sentinel objects +# Note: This needs to be a unique object that is not None, as None is used to +# indicate a successful result in AsyncResult. +# This should probably be an Enum, it would certainly be cleaner, but don't +# want to change the code too much. _NONE = object() +CallbackFunc = Callable[..., None] + -class AsyncResult(object): +class AsyncResult(IAsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler, condition_factory, timeout_factory): + def __init__( + self, + handler: Any, + condition_factory: Callable[[], Any], + timeout_factory: Callable[[], Any], + ) -> None: self._handler = handler - self._exception = _NONE + self._exception: Union[object, None, Exception] = _NONE self._condition = condition_factory() - self._callbacks = [] + self._callbacks: list[CallbackFunc] = [] self._timeout_factory = timeout_factory self.value = None - def ready(self): + def ready(self) -> bool: """Return true if and only if it holds a value or an exception""" return self._exception is not _NONE - def successful(self): + def successful(self) -> bool: """Return true if and only if it is ready and holds a value""" return self._exception is None @property - def exception(self): + def exception(self) -> Optional[Exception]: if self._exception is not _NONE: - return self._exception + # The next line should have return-value, but hound ci + # is frankly nothing but a hound dog + return self._exception # type: ignore + return None - def set(self, value=None): + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters.""" with self._condition: self.value = value @@ -52,14 +76,14 @@ def set(self, value=None): self._do_callbacks() self._condition.notify_all() - def set_exception(self, exception): + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters.""" with self._condition: self._exception = exception self._do_callbacks() self._condition.notify_all() - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: """Return the stored value or raise the exception. If there is no value raises TimeoutError. @@ -69,18 +93,18 @@ def get(self, block=True, timeout=None): if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] elif block: self._condition.wait(timeout) if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] # if we get to this point we timeout raise self._timeout_factory() - def get_nowait(self): + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raises TimeoutError @@ -88,14 +112,14 @@ def get_nowait(self): """ return self.get(block=False) - def wait(self, timeout=None): + def wait(self, timeout: Optional[float] = None) -> bool: """Block until the instance is ready.""" with self._condition: if not self.ready(): self._condition.wait(timeout) return self._exception is not _NONE - def rawlink(self, callback): + def rawlink(self, callback: CallbackFunc) -> None: """Register a callback to call when a value or an exception is set""" with self._condition: @@ -106,7 +130,7 @@ def rawlink(self, callback): if self.ready(): self._do_callbacks() - def unlink(self, callback): + def unlink(self, callback: CallbackFunc) -> None: """Remove the callback set by :meth:`rawlink`""" with self._condition: if self.ready(): @@ -116,7 +140,7 @@ def unlink(self, callback): if callback in self._callbacks: self._callbacks.remove(callback) - def _do_callbacks(self): + def _do_callbacks(self) -> None: """Execute the callbacks that were registered by :meth:`rawlink`. If the handler is in running state this method only schedules the calls to be performed by the handler. If it's stopped, @@ -131,19 +155,21 @@ def _do_callbacks(self): functools.partial(callback, self)() -def _set_fd_cloexec(fd): +def _set_fd_cloexec(fd: Socket) -> None: flags = fcntl.fcntl(fd, fcntl.F_GETFD) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) -def _set_default_tcpsock_options(module, sock): +def _set_default_tcpsock_options(module: ModuleType, sock: Socket) -> Socket: sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1) if HAS_FNCTL: _set_fd_cloexec(sock) return sock -def create_socket_pair(module, port=0): +def create_socket_pair( + module: ModuleType, port: int = 0 +) -> tuple[Socket, Socket]: """Create socket pair. If socket.socketpair isn't available, we emulate it. @@ -182,7 +208,7 @@ def create_socket_pair(module, port=0): return client_sock, srv_sock -def create_tcp_socket(module): +def create_tcp_socket(module: ModuleType) -> Socket: """Create a TCP socket with the CLOEXEC flag set.""" type_ = module.SOCK_STREAM if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover @@ -194,20 +220,20 @@ def create_tcp_socket(module): def create_tcp_connection( - module, - address, - hostname=None, - timeout=None, - use_ssl=False, - ca=None, - certfile=None, - keyfile=None, - keyfile_password=None, - verify_certs=True, - check_hostname=False, - options=None, - ciphers=None, -): + module: ModuleType, + address: Any, + hostname: Optional[str] = None, + timeout: Optional[float] = None, + use_ssl: bool = False, + ca: Optional[str] = None, + certfile: Optional[str] = None, + keyfile: Optional[str] = None, + keyfile_password: Optional[str] = None, + verify_certs: bool = True, + check_hostname: bool = False, + options: Optional[ssl.Options] = None, + ciphers: Optional[str] = None, +) -> Socket: end = None if timeout is None: # thanks to create_connection() developers for @@ -215,7 +241,7 @@ def create_tcp_connection( timeout = module.getdefaulttimeout() if timeout is not None: end = time.monotonic() + timeout - sock = None + sock: Optional[Socket] = None while True: timeout_at = end if end is None else end - time.monotonic() @@ -279,7 +305,13 @@ def create_tcp_connection( sock = module.create_connection(address, timeout_at) break except Exception as ex: - errnum = ex.errno if isinstance(ex, OSError) else ex[0] + # Seriously WTF? if ex is an exception, how can it be a tuple? + # I guess gevent can do this, but really... + errnum = ( + ex.errno + if isinstance(ex, OSError) + else ex[0] # type: ignore + ) if errnum == errno.EINTR: continue raise @@ -291,7 +323,9 @@ def create_tcp_connection( return sock -def capture_exceptions(async_result): +def capture_exceptions( + async_result: IAsyncResult, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -299,9 +333,9 @@ def capture_exceptions(async_result): """ - def capture(function): + def capture(function: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(function) - def captured_function(*args, **kwargs): + def captured_function(*args: Any, **kwargs: Any) -> Any: try: return function(*args, **kwargs) except Exception as exc: @@ -312,7 +346,9 @@ def captured_function(*args, **kwargs): return capture -def wrap(async_result): +def wrap( + async_result: IAsyncResult, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a non-None return value. @@ -321,9 +357,9 @@ def wrap(async_result): """ - def capture(function): + def capture(function: Callable[..., Any]) -> Callable[..., Any]: @capture_exceptions(async_result) - def captured_function(*args, **kwargs): + def captured_function(*args: Any, **kwargs: Any) -> Any: value = function(*args, **kwargs) if value is not None: async_result.set(value) @@ -334,7 +370,7 @@ def captured_function(*args, **kwargs): return capture -def fileobj_to_fd(fileobj): +def fileobj_to_fd(fileobj: Any) -> int: """Return a file descriptor from a file object. Parameters: @@ -359,8 +395,12 @@ def fileobj_to_fd(fileobj): def selector_select( - rlist, wlist, xlist, timeout=None, selectors_module=selectors -): + rlist: list[Any], + wlist: list[Any], + xlist: list[Any], + timeout: Optional[float] = None, + selectors_module: ModuleType = selectors, +) -> tuple[list[int], list[int], list[int]]: """Selector-based drop-in replacement for select to overcome select limitation on a maximum filehandle value. """ @@ -374,8 +414,8 @@ def selector_select( selectors_module.EVENT_READ: rlist, selectors_module.EVENT_WRITE: wlist, } - fd_events = defaultdict(int) - fd_fileobjs = defaultdict(list) + fd_events: defaultdict[int, int] = defaultdict(int) + fd_fileobjs: defaultdict[int, list[int]] = defaultdict(list) for event, fileobjs in events_mapping.items(): for fileobj in fileobjs: @@ -391,7 +431,9 @@ def selector_select( # gevent can raise OSError raise ValueError("Invalid event mask or fd") from e - revents, wevents, xevents = [], [], [] + revents: list[int] = [] + wevents: list[int] = [] + xevents: list[int] = [] try: ready = selector.select(timeout) finally: diff --git a/kazoo/hosts.py b/kazoo/hosts.py index 3ece93180..34e2c3a15 100644 --- a/kazoo/hosts.py +++ b/kazoo/hosts.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import urllib.parse +from typing import Optional, Union -def collect_hosts(hosts): +def collect_hosts( + hosts: Union[str, list[str]], +) -> tuple[list[tuple[str, int]], Optional[str]]: """ Collect a set of hosts and an optional chroot from a string or a list of strings. @@ -12,8 +17,8 @@ def collect_hosts(hosts): else: host_ports, chroot = hosts, None else: - host_ports, chroot = hosts.partition("/")[::2] - host_ports = host_ports.split(",") + host_ports_1, chroot = hosts.partition("/")[::2] + host_ports = host_ports_1.split(",") chroot = "/" + chroot if chroot else None result = [] diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 351f1fd89..9e688d708 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -8,10 +8,54 @@ """ +from __future__ import annotations + +import abc +import queue + +from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING + +if TYPE_CHECKING: + from kazoo.protocol.states import Callback + # public API -class IHandler(object): +class Socket(Protocol): + """This is for things that provide a socket.socket-like interface. + + This is required because: + 1. The socket in gevent doesn't inherit from socket.socket + 2. mypy gets confused if you have a method called socket and + subsequently attempt to use socket or socket.socket as a return type + """ + + def close(self) -> None: + ... + + def fileno(self) -> int: + ... + + def recv(self, bufsize: int, flags: int = 0) -> bytes: + ... + + def send(self, data: bytes | memoryview, flags: int = 0) -> int: + ... + + def sendall(self, data: bytes, flags: int = 0) -> None: + ... + + def setblocking(self, flags: bool) -> None: + ... + + def setsockopt(self, level: int, optname: int, value: int) -> None: + ... + + +SpawnedFunc = Callable[..., None] + + +class IHandler(abc.ABC): """A Callback Handler for Zookeeper completion and watch callbacks. This object must implement several methods responsible for @@ -44,43 +88,66 @@ class IHandler(object): """ - def start(self): + timeout_exception: type[Exception] = None # type: ignore[assignment] + sleep_func: staticmethod[[float], None] = None # type: ignore[assignment] + queue_impl: type[queue.Queue] = None # type: ignore[assignment] + + @abc.abstractmethod + def start(self) -> None: """Start the handler, used for setting up the handler.""" - def stop(self): + @abc.abstractmethod + def stop(self) -> None: """Stop the handler. Should block until the handler is safely stopped.""" - def select(self): + @abc.abstractmethod + def select( + self, + rlist: list, + wlist: list, + xlist: list, + timeout: Optional[float] = None, + ) -> tuple[list, list, list]: """A select method that implements Python's select.select API""" - def socket(self): - """A socket method that implements Python's socket.socket + @abc.abstractmethod + def socket(self) -> Socket: + """A socket method that implements Python's socket.socket API""" + + @abc.abstractmethod + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + """A socket method that implements Python's socket.create_connection API""" - def create_connection(self): - """A socket method that implements Python's - socket.create_connection API""" + @abc.abstractmethod + def create_socket_pair(self) -> tuple[Socket, Socket]: + """A socket method that implements Python's socket.socketpair API""" - def event_object(self): + @abc.abstractmethod + def event_object(self) -> Any: """Return an appropriate object that implements Python's threading.Event API""" - def lock_object(self): + @abc.abstractmethod + def lock_object(self) -> Any: """Return an appropriate object that implements Python's threading.Lock API""" - def rlock_object(self): + @abc.abstractmethod + def rlock_object(self) -> Any: """Return an appropriate object that implements Python's threading.RLock API""" - def async_result(self): + @abc.abstractmethod + def async_result(self) -> IAsyncResult: """Return an instance that conforms to the :class:`~IAsyncResult` interface appropriate for this handler""" - def spawn(self, func, *args, **kwargs): + @abc.abstractmethod + def spawn(self, func: SpawnedFunc, *args: Any, **kwargs: Any) -> Any: """Spawn a function to run asynchronously :param args: args to call the function with. @@ -91,7 +158,8 @@ def spawn(self, func, *args, **kwargs): """ - def dispatch_callback(self, callback): + @abc.abstractmethod + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object :param callback: A :class:`~kazoo.protocol.states.Callback` @@ -100,7 +168,7 @@ def dispatch_callback(self, callback): """ -class IAsyncResult(object): +class IAsyncResult(abc.ABC): """An Async Result object that can be queried for a value that has been set asynchronously. @@ -123,15 +191,18 @@ class IAsyncResult(object): """ - def ready(self): + @abc.abstractmethod + def ready(self) -> bool: """Return `True` if and only if it holds a value or an exception""" - def successful(self): + @abc.abstractmethod + def successful(self) -> bool: """Return `True` if and only if it is ready and holds a value""" - def set(self, value=None): + @abc.abstractmethod + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters. :param value: Value to store as the result. @@ -140,7 +211,8 @@ def set(self, value=None): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def set_exception(self, exception): + @abc.abstractmethod + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters. :param exception: Exception to raise when fetching the value. @@ -149,7 +221,8 @@ def set_exception(self, exception): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def get(self, block=True, timeout=None): + @abc.abstractmethod + def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: """Return the stored value or raise the exception :param block: Whether this method should block or return @@ -164,13 +237,15 @@ def get(self, block=True, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def get_nowait(self): + @abc.abstractmethod + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raise the Timeout exception class on the associated :class:`IHandler` interface.""" - def wait(self, timeout=None): + @abc.abstractmethod + def wait(self, timeout: Optional[float] = None) -> Any: """Block until the instance is ready. :param timeout: How long to wait for a value when `block` is @@ -182,7 +257,8 @@ def wait(self, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def rawlink(self, callback): + @abc.abstractmethod + def rawlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """Register a callback to call when a value or an exception is set @@ -194,10 +270,17 @@ def rawlink(self, callback): """ - def unlink(self, callback): + @abc.abstractmethod + def unlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """Remove the callback set by :meth:`rawlink` :param callback: A callback function to remove. :type callback: func """ + + @property + @abc.abstractmethod + def exception(self) -> Optional[Exception]: + """The exception set by :meth:`set_exception` or `None` if no + exception has been set""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 3df7b1626..0c3ae30f4 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -1,4 +1,7 @@ """Zookeeper Protocol Connection Handler""" + +from __future__ import annotations + from binascii import hexlify from contextlib import contextmanager import copy @@ -8,6 +11,7 @@ import socket import ssl import time +from typing import Any, Iterator, Literal, Optional, Union, TYPE_CHECKING from kazoo.exceptions import ( AuthFailedError, @@ -44,6 +48,10 @@ RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import Socket + try: import puresasl import puresasl.client @@ -76,7 +84,7 @@ # removed from Python3+ -def buffer(obj, offset=0): +def buffer(obj: Any, offset: int = 0) -> memoryview: return memoryview(obj)[offset:] @@ -94,22 +102,30 @@ class RWPinger(object): """ - def __init__(self, hosts, connection_func, socket_handling): + def __init__( + self, + hosts: Any, + connection_func: Any, + socket_handling: Any, + ): self.hosts = hosts self.connection = connection_func - self.last_attempt = None + self.last_attempt: Optional[float] = None self.socket_handling = socket_handling - def __iter__(self): + def __iter__(self) -> Iterator[Union[tuple, None, bool]]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay): + def _next_server(self, delay: float) -> Union[tuple, None, bool]: jitter = random.randint(0, 100) / 100.0 - while time.monotonic() < self.last_attempt + delay + jitter: + while ( + time.monotonic() + < self.last_attempt + delay + jitter # type: ignore[operator] + ): # Skip rw ping checks if its too soon return False for host, port in self.hosts: @@ -128,10 +144,17 @@ def _next_server(self, delay): except ConnectionDropped: return False + # NOTE: This does actually look like it's unreachable but I don't + # want to alter the code any more than necessary for the first + # pass. See https://github.com/python-zk/kazoo/issues/772 + # The loop is basically a sleep with jitter that can be # Add some jitter between host pings - while time.monotonic() < self.last_attempt + jitter: + while ( # type: ignore[unreachable] + time.monotonic() < self.last_attempt + jitter + ): return False delay *= 2 + return None class RWServerAvailable(Exception): @@ -141,7 +164,13 @@ class RWServerAvailable(Exception): class ConnectionHandler(object): """Zookeeper connection handler""" - def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): + def __init__( + self, + client: KazooClient, + retry_sleeper: Any, + logger: Optional[logging.Logger] = None, + sasl_options: Optional[dict] = None, + ): self.client = client self.handler = client.handler self.retry_sleeper = retry_sleeper @@ -154,15 +183,15 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.connection_stopped.set() self.ping_outstanding = client.handler.event_object() - self._read_sock = None - self._write_sock = None + self._read_sock: Optional[Socket] = None + self._write_sock: Optional[Socket] = None - self._socket = None - self._xid = None - self._rw_server = None - self._ro_mode = False + self._socket: Optional[Socket] = None + self._xid: Optional[int] = None + self._rw_server: Optional[tuple] = None + self._ro_mode: Optional[Union[Literal[False], Iterator]] = False - self._connection_routine = None + self._connection_routine: Optional[Any] = None self.sasl_options = sasl_options self.sasl_cli = None @@ -170,14 +199,14 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): # This is instance specific to avoid odd thread bug issues in Python # during shutdown global cleanup @contextmanager - def _socket_error_handling(self): + def _socket_error_handling(self) -> Any: try: yield except (socket.error, select.error) as e: err = getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) - def start(self): + def start(self) -> None: """Start the connection up""" if self.connection_closed.is_set(): rw_sockets = self.handler.create_socket_pair() @@ -189,7 +218,7 @@ def start(self): ) self._connection_routine = self.handler.spawn(self.zk_loop) - def stop(self, timeout=None): + def stop(self, timeout: Optional[float] = None) -> bool: """Ensure the writer has stopped, wait to see if it does.""" self.connection_stopped.wait(timeout) if self._connection_routine: @@ -197,7 +226,7 @@ def stop(self, timeout=None): self._connection_routine = None return self.connection_stopped.is_set() - def close(self): + def close(self) -> None: """Release resources held by the connection The connection can be restarted afterwards. @@ -212,7 +241,7 @@ def close(self): if rs is not None: rs.close() - def _server_pinger(self): + def _server_pinger(self) -> RWPinger: """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" return RWPinger( @@ -221,16 +250,19 @@ def _server_pinger(self): self._socket_error_handling, ) - def _read_header(self, timeout): + def _read_header(self, timeout: Optional[float]) -> tuple: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) header, offset = ReplyHeader.deserialize(b, 0) return header, b, offset - def _read(self, length, timeout): + def _read(self, length: int, timeout: Optional[float]) -> bytes: msgparts = [] remaining = length + # We know that self._socket is not None here because we only call + # this method when we have set up the connection and the read socket. + # But mypy doesn't understand that. with self._socket_error_handling(): while remaining > 0: # Because of SSL framing, a select may not return when using @@ -240,7 +272,7 @@ def _read(self, length, timeout): # data from the underlying socket. if ( hasattr(self._socket, "pending") - and self._socket.pending() > 0 + and self._socket.pending() > 0 # type: ignore[union-attr] ): pass else: @@ -252,7 +284,9 @@ def _read(self, length, timeout): "socket time-out during read" ) try: - chunk = self._socket.recv(remaining) + chunk = self._socket.recv( # type: ignore[union-attr] + remaining + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -267,7 +301,12 @@ def _read(self, length, timeout): remaining -= len(chunk) return b"".join(msgparts) - def _invoke(self, timeout, request, xid=None): + def _invoke( + self, + timeout: Optional[float], + request: Any, + xid: Optional[int] = None, + ) -> Any: """A special writer used during connection establishment only""" self._submit(request, timeout, xid) @@ -311,7 +350,12 @@ def _invoke(self, timeout, request, xid=None): return zxid - def _submit(self, request, timeout, xid=None): + def _submit( + self, + request: Any, + timeout: Optional[float], + xid: Optional[int] = None, + ) -> None: """Submit a request object with a timeout value and optional xid""" b = bytearray() @@ -328,7 +372,7 @@ def _submit(self, request, timeout, xid=None): ) self._write(int_struct.pack(len(b)) + b, timeout) - def _write(self, msg, timeout): + def _write(self, msg: bytes, timeout: Optional[float]) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) @@ -343,7 +387,9 @@ def _write(self, msg, timeout): ) msg_slice = buffer(msg, sent) try: - bytes_sent = self._socket.send(msg_slice) + bytes_sent = self._socket.send( # type:ignore[union-attr] + msg_slice + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -356,14 +402,14 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent - def _read_watch_event(self, buffer, offset): + def _read_watch_event(self, buffer: bytes, offset: int) -> None: client = self.client watch, offset = Watch.deserialize(buffer, offset) path = watch.path self.logger.debug("Received EVENT: %s", watch) - watchers = [] + watchers: list[WatchFunc] = [] if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) @@ -385,10 +431,15 @@ def _read_watch_event(self, buffer, offset): return # Dump the watchers to the watch thread - for watch in watchers: - client.handler.dispatch_callback(Callback("watch", watch, (ev,))) - - def _read_response(self, header, buffer, offset): + for watch1 in watchers: + client.handler.dispatch_callback(Callback("watch", watch1, (ev,))) + + def _read_response( + self, + header: Any, + buffer: bytes, + offset: int, + ) -> Optional[object]: client = self.client request, async_object, xid = client._pending.popleft() if header.zxid and header.zxid > 0: @@ -404,7 +455,11 @@ def _read_response(self, header, buffer, offset): # Determine if its an exists request and a no node error exists_error = ( - header.err == NoNodeError.code and request.type == Exists.type + # NoNodeError does actually have a code. It's added by a wrapper, + # which could possibly be better done via inheritance but this is + # less invasive to the existing code. + header.err == NoNodeError.code # type: ignore[attr-defined] + and request.type == Exists.type ) # Set the exception if its not an exists error @@ -430,7 +485,7 @@ def _read_response(self, header, buffer, offset): request, ) async_object.set_exception(exc) - return + return None self.logger.debug( "Received response(xid=%s): %r", xid, response ) @@ -452,8 +507,9 @@ def _read_response(self, header, buffer, offset): if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE + return None - def _read_socket(self, read_timeout): + def _read_socket(self, read_timeout: float) -> Optional[object]: """Called when there's something to read on the socket""" client = self.client @@ -476,8 +532,13 @@ def _read_socket(self, read_timeout): self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) + return None - def _send_request(self, read_timeout, connect_timeout): + def _send_request( + self, + read_timeout: float, + connect_timeout: float, + ) -> None: """Called when we have something to send out on the socket""" client = self.client try: @@ -489,7 +550,11 @@ def _send_request(self, read_timeout, connect_timeout): try: # Clear possible inconsistence (no request in the queue # but have data in the read socket), which causes cpu to spin. - self._read_sock.recv(1) + # + # We know _read_sock is not None because we only call this + # method when we have set up the connection and the read + # socket, but mypy doesn't understand that. + self._read_sock.recv(1) # type: ignore[union-attr] except OSError: pass return @@ -505,15 +570,19 @@ def _send_request(self, read_timeout, connect_timeout): if request.type == Auth.type: xid = AUTH_XID else: - self._xid = (self._xid % 2147483647) + 1 + # We must have initialised the xid counter by now + # Might want to consider initialising it to 0 instead of none? + self._xid = (self._xid % 2147483647) + 1 # type: ignore[operator] xid = self._xid self._submit(request, connect_timeout, xid) client._queue.popleft() - self._read_sock.recv(1) + # _read_sock should never be None here as we only call this method + # when we have set up the connection and the read socket. + self._read_sock.recv(1) # type: ignore[union-attr] client._pending.append((request, async_object, xid)) - def _send_ping(self, connect_timeout): + def _send_ping(self, connect_timeout: float) -> None: self.ping_outstanding.set() self._submit(PingInstance, connect_timeout, PING_XID) @@ -524,7 +593,7 @@ def _send_ping(self, connect_timeout): self._rw_server = result raise RWServerAvailable() - def zk_loop(self): + def zk_loop(self) -> None: """Main Zookeeper handling loop""" self.logger.log(BLATHER, "ZK loop started") @@ -546,7 +615,7 @@ def zk_loop(self): self.client._session_callback(KeeperState.CLOSED) self.logger.log(BLATHER, "Connection stopped") - def _expand_client_hosts(self): + def _expand_client_hosts(self) -> list: # Expand the entire list in advance so we can randomize it if needed host_ports = [] for host, port in self.client.hosts: @@ -564,7 +633,7 @@ def _expand_client_hosts(self): random.shuffle(host_ports) return host_ports - def _connect_loop(self, retry): + def _connect_loop(self, retry: Any) -> object: # Iterate through the hosts a full cycle before starting over status = None host_ports = self._expand_client_hosts() @@ -586,7 +655,13 @@ def _connect_loop(self, retry): else: raise ForceRetryError("Reconnecting") - def _connect_attempt(self, host, hostip, port, retry): + def _connect_attempt( + self, + host: str, + hostip: str, + port: int, + retry: Any, + ) -> object: client = self.client KazooTimeoutError = self.handler.timeout_exception @@ -674,9 +749,18 @@ def _connect_attempt(self, host, hostip, port, retry): raise finally: if self._socket is not None: - self._socket.close() - - def _connect(self, host, hostip, port): + # I think this is a bug in mypy, as the socket does get set up + # in self._connect, but it doesn't seem to be able to track + # that. + self._socket.close() # type: ignore[unreachable] + return None + + def _connect( + self, + host: str, + hostip: str, + port: int, + ) -> tuple[float, float]: client = self.client self.logger.info( "Connecting to %s(%s):%s, use_ssl: %r", @@ -707,7 +791,7 @@ def _connect(self, host, hostip, port): check_hostname=self.client.check_hostname, ) - self._socket.setblocking(0) + self._socket.setblocking(0) # type: ignore[arg-type] connect = Connect( 0, @@ -771,23 +855,36 @@ def _connect(self, host, hostip, port): return read_timeout, connect_timeout - def _authenticate_with_sasl(self, host, timeout): + def _authenticate_with_sasl(self, host: str, timeout: float) -> None: """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: raise SASLException("Missing SASL support") - if "service" not in self.sasl_options: - self.sasl_options["service"] = "zookeeper" + # Although this can only be called if sasl_options is not None, we + # really should just have make self.sasl_options into an empty dict + # in the constructor. However, I want to avoid code changes in as + # much as possible. + if "service" not in self.sasl_options: # type: ignore[operator] + self.sasl_options["service"] = "zookeeper" # type: ignore[index] # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. - if self.sasl_options["mechanism"] == "DIGEST-MD5": + if ( + self.sasl_options["mechanism"] # type: ignore[index] + == "DIGEST-MD5" + ): host = "zk-sasl-md5" - sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, **self.sasl_options + # I don't think the client.sasl_cli attribute is actually used + # anywhere else, so not sure why we need to set it on the client, + # but again, I want to avoid code changes as much as possible. + sasl_cli = ( + self.client.sasl_cli # type: ignore[attr-defined] + ) = puresasl.client.SASLClient( + host=host, + **self.sasl_options, # type: ignore[arg-type] ) # Initialize the process with an empty challenge token diff --git a/kazoo/protocol/paths.py b/kazoo/protocol/paths.py index b8bf66507..7c47ce8a0 100644 --- a/kazoo/protocol/paths.py +++ b/kazoo/protocol/paths.py @@ -1,4 +1,4 @@ -def normpath(path, trailing=False): +def normpath(path: str, trailing: bool = False) -> str: """Normalize path, eliminating double slashes, etc.""" comps = path.split("/") new_comps = [] @@ -16,7 +16,7 @@ def normpath(path, trailing=False): return new_path -def join(a, *p): +def join(a: str, *p: str) -> str: """Join two or more pathname components, inserting '/' as needed. If any component is an absolute path, all previous path components @@ -34,23 +34,23 @@ def join(a, *p): return path -def isabs(s): +def isabs(s: str) -> bool: """Test whether a path is absolute""" return s.startswith("/") -def basename(p): +def basename(p: str) -> str: """Returns the final component of a pathname""" i = p.rfind("/") + 1 return p[i:] -def _prefix_root(root, path, trailing=False): +def _prefix_root(root: str, path: str, trailing: bool = False) -> str: """Prepend a root to a path.""" return normpath( join(_norm_root(root), path.lstrip("/")), trailing=trailing ) -def _norm_root(root): +def _norm_root(root: str) -> str: return normpath(join("/", root)) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c2..4311d4e42 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -1,12 +1,17 @@ """Zookeeper Serializers, Deserializers, and NamedTuple objects""" -from collections import namedtuple +from __future__ import annotations + import struct +from collections import namedtuple +from typing import Any, ClassVar, Optional, Sequence, Union, TYPE_CHECKING from kazoo.exceptions import EXCEPTIONS from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc # Struct objects with formats compiled bool_struct = struct.Struct("B") @@ -21,7 +26,7 @@ stat_struct = struct.Struct("!qqqqiiiqiiq") -def read_string(buffer, offset): +def read_string(buffer: bytes, offset: int) -> tuple: """Reads an int specified buffer into a string and returns the string and the new offset in the buffer""" length = int_struct.unpack_from(buffer, offset)[0] @@ -34,7 +39,7 @@ def read_string(buffer, offset): return buffer[index : index + length].decode("utf-8"), offset -def read_acl(bytes, offset): +def read_acl(bytes: bytes, offset: int) -> tuple: perms = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size scheme, offset = read_string(bytes, offset) @@ -42,7 +47,7 @@ def read_acl(bytes, offset): return ACL(perms, Id(scheme, id)), offset -def write_string(bytes): +def write_string(bytes: Optional[str]) -> bytes: if not bytes: return int_struct.pack(-1) else: @@ -50,14 +55,14 @@ def write_string(bytes): return int_struct.pack(len(utf8_str)) + utf8_str -def write_buffer(bytes): +def write_buffer(bytes: Optional[bytes]) -> bytes: if bytes is None: return int_struct.pack(-1) else: return int_struct.pack(len(bytes)) + bytes -def read_buffer(bytes, offset): +def read_buffer(bytes: bytes, offset: int) -> tuple: length = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if length < 0: @@ -69,10 +74,10 @@ def read_buffer(bytes, offset): class Close(namedtuple("Close", "")): - type = -11 + type: ClassVar[int] = -11 @classmethod - def serialize(cls): + def serialize(self) -> bytes: return b"" @@ -80,10 +85,10 @@ def serialize(cls): class Ping(namedtuple("Ping", "")): - type = 11 + type: ClassVar[int] = 11 @classmethod - def serialize(cls): + def serialize(cls) -> bytes: return b"" @@ -97,9 +102,16 @@ class Connect( " time_out session_id passwd read_only", ) ): - type = None + protocol_version: int + last_zxid_seen: int + time_out: int + session_id: int + passwd: bytes + read_only: bool - def serialize(self): + type: Optional[int] = None # Note: Not a classvar + + def serialize(self) -> bytearray: b = bytearray() b.extend( int_long_int_long_struct.pack( @@ -114,7 +126,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Any, int]: proto_version, timeout, session_id = int_int_long_struct.unpack_from( bytes, offset ) @@ -133,9 +145,14 @@ def deserialize(cls, bytes, offset): class Create(namedtuple("Create", "path data acl flags")): - type = 1 + path: str + data: Optional[bytes] + acl: Sequence[ACL] + flags: int + + type: ClassVar[int] = 1 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -150,59 +167,74 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> str: return read_string(bytes, offset)[0] class Delete(namedtuple("Delete", "path version")): - type = 2 + path: str + version: int + + type: ClassVar[int] = 2 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> bool: return True class Exists(namedtuple("Exists", "path watcher")): - type = 3 + path: str + watcher: Optional[WatchFunc] - def serialize(self): + type: ClassVar[int] = 3 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> Optional[ZnodeStat]: + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return stat if stat.czxid != -1 else None class GetData(namedtuple("GetData", "path watcher")): - type = 4 + path: str + watcher: Optional[WatchFunc] + + type: ClassVar[int] = 4 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[Optional[bytes], ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class SetData(namedtuple("SetData", "path data version")): - type = 5 + path: str + data: Optional[bytes] + version: int + + type: ClassVar[int] = 5 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -210,18 +242,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetACL(namedtuple("GetACL", "path")): - type = 6 + path: str - def serialize(self): + type: ClassVar[int] = 6 + + def serialize(self) -> bytearray: return bytearray(write_string(self.path)) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> Union[tuple[list[ACL], ZnodeStat], list[ACL]]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -231,14 +267,18 @@ def deserialize(cls, bytes, offset): for c in range(count): acl, offset = read_acl(bytes, offset) acls.append(acl) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return acls, stat class SetACL(namedtuple("SetACL", "path acls version")): - type = 7 + path: str + acls: Sequence[ACL] + version: int + + type: ClassVar[int] = 7 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(len(self.acls))) @@ -252,21 +292,24 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetChildren(namedtuple("GetChildren", "path watcher")): - type = 8 + path: str + watcher: Optional[WatchFunc] + + type: ClassVar[int] = 8 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -280,27 +323,34 @@ def deserialize(cls, bytes, offset): class Sync(namedtuple("Sync", "path")): - type = 9 + path: str - def serialize(self): + type: ClassVar[int] = 9 + + def serialize(self) -> bytes: return write_string(self.path) @classmethod - def deserialize(cls, buffer, offset): + def deserialize(cls, buffer: bytes, offset: int) -> str: return read_string(buffer, offset)[0] class GetChildren2(namedtuple("GetChildren2", "path watcher")): - type = 12 + path: str + watcher: Optional[WatchFunc] + + type: ClassVar[int] = 12 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> Union[tuple[list[str], ZnodeStat], list[str]]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -310,14 +360,17 @@ def deserialize(cls, bytes, offset): for c in range(count): child, offset = read_string(bytes, offset) children.append(child) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return children, stat class CheckVersion(namedtuple("CheckVersion", "path version")): - type = 13 + path: str + version: int + + type: ClassVar[int] = 13 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) @@ -325,9 +378,11 @@ def serialize(self): class Transaction(namedtuple("Transaction", "operations")): - type = 14 + operations: list[Any] - def serialize(self): + type: ClassVar[int] = 14 + + def serialize(self) -> bytearray: b = bytearray() for op in self.operations: b.extend( @@ -336,7 +391,7 @@ def serialize(self): return b + multiheader_struct.pack(-1, True, -1) @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> list[Any]: header = MultiHeader(None, False, None) results = [] response = None @@ -346,9 +401,7 @@ def deserialize(cls, bytes, offset): elif header.type == Delete.type: response = True elif header.type == SetData.type: - response = ZnodeStat._make( - stat_struct.unpack_from(bytes, offset) - ) + response = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) offset += stat_struct.size elif header.type == CheckVersion.type: response = True @@ -362,7 +415,7 @@ def deserialize(cls, bytes, offset): return results @staticmethod - def unchroot(client, response): + def unchroot(client: KazooClient, response: list[Any]) -> list[Any]: resp = [] for result in response: if isinstance(result, str): @@ -373,9 +426,14 @@ def unchroot(client, response): class Create2(namedtuple("Create2", "path data acl flags")): - type = 15 + path: str + data: Optional[bytes] + acl: Sequence[ACL] + flags: int + + type: ClassVar[int] = 15 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -390,18 +448,23 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[str, ZnodeStat]: path, offset = read_string(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return path, stat class Reconfig( namedtuple("Reconfig", "joining leaving new_members config_id") ): - type = 16 + joining: Optional[str] + leaving: Optional[str] + new_members: Optional[str] + config_id: int + + type: ClassVar[int] = 16 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.joining)) b.extend(write_string(self.leaving)) @@ -410,16 +473,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[Optional[bytes], ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class Auth(namedtuple("Auth", "auth_type scheme auth")): - type = 100 + auth_type: int + scheme: str + auth: str - def serialize(self): + type: ClassVar[int] = 100 + + def serialize(self) -> bytes: return ( int_struct.pack(self.auth_type) + write_string(self.scheme) @@ -428,22 +497,30 @@ def serialize(self): class SASL(namedtuple("SASL", "challenge")): - type = 102 + challenge: Optional[bytes] + + type: ClassVar[int] = 102 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_buffer(self.challenge)) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[Optional[bytes], int]: challenge, offset = read_buffer(bytes, offset) return challenge, offset class Watch(namedtuple("Watch", "type state path")): + type: int + state: int + path: str + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Watch, int]: """Given bytes and the current bytes offset, return the type, state, path, and new offset""" type, state = int_int_struct.unpack_from(bytes, offset) @@ -453,19 +530,27 @@ def deserialize(cls, bytes, offset): class ReplyHeader(namedtuple("ReplyHeader", "xid, zxid, err")): + xid: int + zxid: int + err: int + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[ReplyHeader, int]: """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size return ( - cls._make(reply_header_struct.unpack_from(bytes, offset)), + cls(*reply_header_struct.unpack_from(bytes, offset)), new_offset, ) -class MultiHeader(namedtuple("MultiHeader", "type done err")): - def serialize(self): +class MultiHeader(namedtuple("MultiHeader", "type, done, err")): + type: Optional[int] + done: bool + err: Optional[int] + + def serialize(self) -> bytearray: b = bytearray() b.extend(int_struct.pack(self.type)) b.extend([1 if self.done else 0]) @@ -473,7 +558,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[MultiHeader, int]: t, done, err = multiheader_struct.unpack_from(bytes, offset) offset += multiheader_struct.size return cls(t, done == 1, err), offset diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e8..5f2a23e60 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -1,8 +1,13 @@ """Kazoo State and Event objects""" -from collections import namedtuple +from __future__ import annotations -class KazooState(object): +from enum import Enum +from typing import Callable, NamedTuple, Optional + + +# This is a (str, Enum) for backwards compatibility. +class KazooState(str, Enum): """High level connection state values States inspired by Netflix Curator. @@ -33,7 +38,8 @@ class KazooState(object): LOST = "LOST" -class KeeperState(object): +# This is a (str, Enum) for backwards compatibility. +class KeeperState(str, Enum): """Zookeeper State Represents the Zookeeper state. Watch functions will receive a @@ -70,7 +76,8 @@ class KeeperState(object): EXPIRED_SESSION = "EXPIRED_SESSION" -class EventType(object): +# This is a (str, Enum) for backwards compatibility. +class EventType(str, Enum): """Zookeeper Event Represents a Zookeeper event. Events trigger watch functions which @@ -117,7 +124,7 @@ class EventType(object): } -class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): +class WatchedEvent(NamedTuple): """A change on ZooKeeper that a Watcher is able to respond to. The :class:`WatchedEvent` includes exactly what happened, the @@ -140,8 +147,12 @@ class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): """ + type: EventType + state: KeeperState + path: Optional[str] + -class Callback(namedtuple("Callback", ("type", "func", "args"))): +class Callback(NamedTuple): """A callback that is handed to a handler for dispatch :param type: Type of the callback, currently is only 'watch' @@ -150,15 +161,12 @@ class Callback(namedtuple("Callback", ("type", "func", "args"))): """ + type: str + func: Callable + args: tuple -class ZnodeStat( - namedtuple( - "ZnodeStat", - "czxid mzxid ctime mtime version" - " cversion aversion ephemeralOwner dataLength" - " numChildren pzxid", - ) -): + +class ZnodeStat(NamedTuple): """A ZnodeStat structure with convenience properties When getting the value of a znode from Zookeeper, the properties for @@ -216,38 +224,50 @@ class ZnodeStat( """ + czxid: int + mzxid: int + ctime: int + mtime: int + version: int + cversion: int + aversion: int + ephemeralOwner: int + dataLength: int + numChildren: int + pzxid: int + @property - def acl_version(self): + def acl_version(self) -> int: return self.aversion @property - def children_version(self): + def children_version(self) -> int: return self.cversion @property - def created(self): + def created(self) -> float: return self.ctime / 1000.0 @property - def last_modified(self): + def last_modified(self) -> float: return self.mtime / 1000.0 @property - def owner_session_id(self): + def owner_session_id(self) -> Optional[int]: return self.ephemeralOwner or None @property - def creation_transaction_id(self): + def creation_transaction_id(self) -> int: return self.czxid @property - def last_modified_transaction_id(self): + def last_modified_transaction_id(self) -> int: return self.mzxid @property - def data_length(self): + def data_length(self) -> int: return self.dataLength @property - def children_count(self): + def children_count(self) -> int: return self.numChildren diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index 683e807b0..9af59e1a5 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -4,13 +4,20 @@ :Status: Unknown """ + +from __future__ import annotations + import os import socket import uuid +from typing import Any, Literal, Optional, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError from kazoo.protocol.states import EventType +if TYPE_CHECKING: + from kazoo.client import KazooClient + class Barrier(object): """Kazoo Barrier @@ -27,7 +34,7 @@ class Barrier(object): """ - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """Create a Kazoo Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -37,11 +44,11 @@ def __init__(self, client, path): self.client = client self.path = path - def create(self): + def create(self) -> None: """Establish the barrier if it doesn't exist already""" self.client.retry(self.client.ensure_path, self.path) - def remove(self): + def remove(self) -> bool: """Remove the barrier :returns: Whether the barrier actually needed to be removed. @@ -54,7 +61,7 @@ def remove(self): except NoNodeError: return False - def wait(self, timeout=None): + def wait(self, timeout: Optional[float] = None) -> bool: """Wait on the barrier to be cleared :returns: True if the barrier has been cleared, otherwise @@ -64,7 +71,7 @@ def wait(self, timeout=None): """ cleared = self.client.handler.event_object() - def wait_for_clear(event): + def wait_for_clear(event: Any) -> None: if event.type == EventType.DELETED: cleared.set() @@ -93,7 +100,13 @@ class DoubleBarrier(object): """ - def __init__(self, client, path, num_clients, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + num_clients: int, + identifier: Optional[str] = None, + ): """Create a Double Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -118,7 +131,7 @@ def __init__(self, client, path, num_clients, identifier=None): self.node_name = uuid.uuid4().hex self.create_path = self.path + "/" + self.node_name - def enter(self): + def enter(self) -> None: """Enter the barrier, blocks until all nodes have entered""" try: self.client.retry(self._inner_enter) @@ -128,7 +141,7 @@ def enter(self): self._best_effort_cleanup() self.participating = False - def _inner_enter(self): + def _inner_enter(self) -> Literal[True]: # make sure our barrier parent node exists if not self.assured_path: self.client.ensure_path(self.path) @@ -145,7 +158,7 @@ def _inner_enter(self): except NodeExistsError: pass - def created(event): + def created(event: Any) -> None: if event.type == EventType.CREATED: ready.set() @@ -159,7 +172,7 @@ def created(event): self.client.ensure_path(self.path + "/ready") return True - def leave(self): + def leave(self) -> None: """Leave the barrier, blocks until all nodes have left""" try: self.client.retry(self._inner_leave) @@ -168,7 +181,7 @@ def leave(self): self._best_effort_cleanup() self.participating = False - def _inner_leave(self): + def _inner_leave(self) -> bool: # Delete the ready node if its around try: self.client.delete(self.path + "/ready") @@ -188,7 +201,7 @@ def _inner_leave(self): ready = self.client.handler.event_object() - def deleted(event): + def deleted(event: Any) -> None: if event.type == EventType.DELETED: ready.set() @@ -214,7 +227,7 @@ def deleted(event): # Wait for the lowest to be deleted ready.wait() - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.retry(self.client.delete, self.create_path) except NoNodeError: diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 0a22a6c7e..7bf33c0b5 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -10,17 +10,26 @@ See also: http://curator.apache.org/curator-recipes/tree-cache.html """ + +from __future__ import annotations from __future__ import absolute_import import contextlib import functools import logging import operator +from typing import Any, Callable, Generator, Optional, Protocol, TYPE_CHECKING + from kazoo.exceptions import NoNodeError, KazooException from kazoo.protocol.paths import _prefix_root, join as kazoo_join from kazoo.protocol.states import KazooState, EventType +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import IAsyncResult + from kazoo.protocol.states import WatchedEvent + logger = logging.getLogger(__name__) @@ -37,18 +46,18 @@ class TreeCache(object): _STOP = object() - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): self._client = client self._root = TreeNode.make_root(self, path) self._state = self.STATE_LATENT self._outstanding_ops = 0 self._is_initialized = False - self._error_listeners = [] - self._event_listeners = [] + self._error_listeners: list[Callable[[Exception], Any]] = [] + self._event_listeners: list[Callable[[TreeEvent], Any]] = [] self._task_queue = client.handler.queue_impl() self._task_thread = None - def start(self): + def start(self) -> None: """Starts the cache. The cache is not started automatically. You must call this method. @@ -85,7 +94,7 @@ def start(self): # without lock. self._in_background(self._root.on_created) - def close(self): + def close(self) -> None: """Closes the cache. A closed cache was detached from ZooKeeper's changes. And all nodes @@ -109,7 +118,9 @@ def close(self): # ZooKeeper actually. self._root.on_deleted() - def listen(self, listener): + def listen( + self, listener: Callable[[TreeEvent], Any] + ) -> Callable[[TreeEvent], Any]: """Registers a function to listen the cache events. The cache events are changes of local data. They are delivered from @@ -124,7 +135,9 @@ def listen(self, listener): self._event_listeners.append(listener) return listener - def listen_fault(self, listener): + def listen_fault( + self, listener: Callable[[Exception], Any] + ) -> Callable[[Exception], Any]: """Registers a function to listen the exceptions. It is possible to meet some exceptions during the cache running. You @@ -138,7 +151,9 @@ def listen_fault(self, listener): self._error_listeners.append(listener) return listener - def get_data(self, path, default=None): + def get_data( + self, path: str, default: Optional[NodeData] = None + ) -> Optional[NodeData]: """Gets data of a node from cache. :param path: The absolute path string. @@ -150,7 +165,9 @@ def get_data(self, path, default=None): node = self._find_node(path) return default if node is None else node._data - def get_children(self, path, default=None): + def get_children( + self, path: str, default: Optional[frozenset[str]] = None + ) -> Optional[frozenset[str]]: """Gets node children list from in-memory snapshot. :param path: The absolute path string. @@ -162,7 +179,7 @@ def get_children(self, path, default=None): node = self._find_node(path) return default if node is None else frozenset(node._children) - def _find_node(self, path): + def _find_node(self, path: str) -> Optional[TreeNode]: if not path.startswith(self._root._path): raise ValueError("outside of tree") striped_path = path[len(self._root._path) :].strip("/") @@ -170,25 +187,27 @@ def _find_node(self, path): current_node = self._root for node_name in splited_path: if node_name not in current_node._children: - return + return None current_node = current_node._children[node_name] return current_node - def _publish_event(self, event_type, event_data=None): + def _publish_event(self, event_type: int, event_data: Any = None) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: logger.debug("public event: %r", event) self._in_background(self._do_publish_event, event) - def _do_publish_event(self, event): + def _do_publish_event(self, event: TreeEvent) -> None: for listener in self._event_listeners: with handle_exception(self._error_listeners): listener(event) - def _in_background(self, func, *args, **kwargs): + def _in_background( + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: self._task_queue.put((func, args, kwargs)) - def _do_background(self): + def _do_background(self) -> None: while True: with handle_exception(self._error_listeners): cb = self._task_queue.get() @@ -200,7 +219,7 @@ def _do_background(self): # release before possible idle del cb, func, args, kwargs - def _session_watcher(self, state): + def _session_watcher(self, state: Any) -> None: if state == KazooState.SUSPENDED: self._publish_event(TreeEvent.CONNECTION_SUSPENDED) elif state == KazooState.CONNECTED: @@ -212,6 +231,11 @@ def _session_watcher(self, state): self._publish_event(TreeEvent.CONNECTION_LOST) +class AsyncWatcher(Protocol): + def __call__(self, path: str, watch: Optional[WatchFunc]) -> IAsyncResult: + ... + + class TreeNode(object): """The tree node record. @@ -234,28 +258,28 @@ class TreeNode(object): STATE_LIVE = 1 STATE_DEAD = 2 - def __init__(self, tree, path, parent): + def __init__(self, tree: TreeCache, path: str, parent: Optional[TreeNode]): self._tree = tree self._path = path self._parent = parent - self._depth = parent._depth + 1 if parent else 0 - self._children = {} + self._depth: int = parent._depth + 1 if parent is not None else 0 + self._children: dict[str, TreeNode] = {} self._state = self.STATE_PENDING - self._data = None + self._data: Optional[NodeData] = None @classmethod - def make_root(cls, tree, path): + def make_root(cls, tree: TreeCache, path: str) -> TreeNode: return cls(tree, path, None) - def on_reconnected(self): + def on_reconnected(self) -> None: self._refresh() for child in self._children.values(): child.on_reconnected() - def on_created(self): + def on_created(self) -> None: self._refresh() - def on_deleted(self): + def on_deleted(self) -> None: old_children, self._children = self._children, {} old_data, self._data = self._data, None @@ -278,37 +302,41 @@ def on_deleted(self): del self._parent._children[child] self._reset_watchers() - def _publish_event(self, *args, **kwargs): + def _publish_event(self, *args: Any, **kwargs: Any) -> Any: return self._tree._publish_event(*args, **kwargs) - def _reset_watchers(self): + def _reset_watchers(self) -> None: client = self._tree._client for _watchers in (client._data_watchers, client._child_watchers): _path = _prefix_root(client.chroot, self._path) _watcher = _watchers.get(_path, set()) _watcher.discard(self._process_watch) - def _refresh(self): + def _refresh(self) -> None: self._refresh_data() self._refresh_children() - def _refresh_data(self): + def _refresh_data(self) -> None: self._call_client("get", self._path) - def _refresh_children(self): + def _refresh_children(self) -> None: # TODO max-depth checking support self._call_client("get_children", self._path) - def _call_client(self, method_name, path): + def _call_client(self, method_name: str, path: str) -> None: assert method_name in ("get", "get_children", "exists") self._tree._outstanding_ops += 1 callback = functools.partial( self._tree._in_background, self._process_result, method_name, path ) - method = getattr(self._tree._client, method_name + "_async") + # The typing for this is really bad but the type checker can + # understand it with a few hacks + method: AsyncWatcher = getattr( + self._tree._client, method_name + "_async" + ) method(path, watch=self._process_watch).rawlink(callback) - def _process_watch(self, watched_event): + def _process_watch(self, watched_event: WatchedEvent) -> None: logger.debug("process_watch: %r", watched_event) with handle_exception(self._tree._error_listeners): if watched_event.type == EventType.CREATED: @@ -321,7 +349,9 @@ def _process_watch(self, watched_event): elif watched_event.type == EventType.CHILD: self._refresh_children() - def _process_result(self, method_name, path, result): + def _process_result( + self, method_name: str, path: str, result: Any + ) -> None: logger.debug("process_result: %s %s", method_name, path) if method_name == "exists": assert self._parent is None, "unexpected EXISTS on non-root" @@ -332,7 +362,7 @@ def _process_result(self, method_name, path, result): self.on_created() elif method_name == "get_children": if result.successful(): - children = result.get() + children: list[str] = result.get() for child in sorted(children): full_path = kazoo_join(path, child) if child not in self._children: @@ -385,7 +415,7 @@ class TreeEvent(tuple): event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type, event_data): + def make(cls, event_type: int, event_data: Any) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. @@ -415,7 +445,7 @@ class NodeData(tuple): stat = property(operator.itemgetter(2)) @classmethod - def make(cls, path, data, stat): + def make(cls, path: str, data: bytes, stat: Any) -> NodeData: """Creates a new NodeData tuple. :returns: A :class:`~kazoo.recipe.cache.NodeData` instance. @@ -424,7 +454,9 @@ def make(cls, path, data, stat): @contextlib.contextmanager -def handle_exception(listeners): +def handle_exception( + listeners: list[Callable[[Exception], Any]], +) -> Generator[None, None, None]: try: yield except Exception as e: diff --git a/kazoo/recipe/counter.py b/kazoo/recipe/counter.py index 3b2cc339c..1be44902d 100644 --- a/kazoo/recipe/counter.py +++ b/kazoo/recipe/counter.py @@ -4,9 +4,20 @@ :Status: Unknown """ + +from __future__ import annotations + +import struct +from typing import Optional, Union, TYPE_CHECKING + from kazoo.exceptions import BadVersionError from kazoo.retry import ForceRetryError -import struct + +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +Number = Union[int, float] class Counter(object): @@ -58,7 +69,13 @@ class Counter(object): """ - def __init__(self, client, path, default=0, support_curator=False): + def __init__( + self, + client: KazooClient, + path: str, + default: Number = 0, + support_curator: bool = False, + ): """Create a Kazoo Counter :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -74,22 +91,24 @@ def __init__(self, client, path, default=0, support_curator=False): self.default_type = type(default) self.support_curator = support_curator self._ensured_path = False - self.pre_value = None - self.post_value = None + self.pre_value: Optional[Number] = None + self.post_value: Optional[Number] = None if self.support_curator and not isinstance(self.default, int): raise TypeError( "when support_curator is enabled the default " "type must be an int" ) - def _ensure_node(self): + def _ensure_node(self) -> None: if not self._ensured_path: # make sure our node exists self.client.ensure_path(self.path) self._ensured_path = True - def _value(self): + def _value(self) -> tuple[Number, int]: self._ensure_node() + # This is astonishingly hard to follow... + old: Union[bytes, str, Number] old, stat = self.client.get(self.path) if self.support_curator: old = struct.unpack(">i", old)[0] if old != b"" else self.default @@ -100,16 +119,16 @@ def _value(self): return data, version @property - def value(self): + def value(self) -> Number: return self._value()[0] - def _change(self, value): + def _change(self, value: Number) -> Counter: if not isinstance(value, self.default_type): raise TypeError("invalid type for value change") self.client.retry(self._inner_change, value) return self - def _inner_change(self, value): + def _inner_change(self, value: Number) -> None: self.pre_value, version = self._value() post_value = self.pre_value + value if self.support_curator: @@ -123,10 +142,10 @@ def _inner_change(self, value): raise ForceRetryError() self.post_value = post_value - def __add__(self, value): + def __add__(self, value: Number) -> Counter: """Add value to counter.""" return self._change(value) - def __sub__(self, value): + def __sub__(self, value: Number) -> Counter: """Subtract value from counter.""" return self._change(-value) diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 93bb72580..82a32cd3c 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -4,8 +4,16 @@ :Status: Unknown """ + +from __future__ import annotations + +from typing import Any, Callable, Optional, TYPE_CHECKING + from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class Election(object): """Kazoo Basic Leader Election @@ -22,7 +30,12 @@ class Election(object): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: Optional[str] = None, + ): """Create a Kazoo Leader Election :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -34,7 +47,7 @@ def __init__(self, client, path, identifier=None): """ self.lock = client.Lock(path, identifier) - def run(self, func, *args, **kwargs): + def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: """Contend for the leadership This call will block until either this contender is cancelled @@ -57,7 +70,7 @@ def run(self, func, *args, **kwargs): except CancelledError: pass - def cancel(self): + def cancel(self) -> None: """Cancel participation in the election .. note:: @@ -68,7 +81,7 @@ def cancel(self): """ self.lock.cancel() - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders in the election diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index ce7fe567c..b43d00c70 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -5,12 +5,19 @@ :Status: Beta """ + +from __future__ import annotations + import datetime import json import socket +from typing import Any, Callable, Optional, TYPE_CHECKING from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class NonBlockingLease(object): """Exclusive lease that does not block. @@ -48,11 +55,11 @@ class NonBlockingLease(object): def __init__( self, - client, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + path: str, + duration: datetime.timedelta, + identifier: Optional[str] = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): """Create a non-blocking lease. @@ -71,7 +78,14 @@ def __init__( self.obtained = False self._attempt_obtaining(client, path, duration, ident, utcnow) - def _attempt_obtaining(self, client, path, duration, ident, utcnow): + def _attempt_obtaining( + self, + client: KazooClient, + path: str, + duration: datetime.timedelta, + ident: str, + utcnow: Callable[[], datetime.datetime], + ) -> None: client.ensure_path(path) holder_path = path + "/lease_holder" lock = client.Lock(path, ident) @@ -103,18 +117,18 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): except CancelledError: pass - def _encode(self, data_dict): + def _encode(self, data_dict: dict[str, Any]) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) - def _decode(self, raw): + def _decode(self, raw: bytes) -> dict[str, Any]: return json.loads(raw.decode(self._byte_encoding)) # Python 2.x - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.obtained # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained @@ -140,12 +154,12 @@ class MultiNonBlockingLease(object): def __init__( self, - client, - count, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + count: int, + path: str, + duration: datetime.timedelta, + identifier: Optional[str] = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): self.obtained = False for num in range(count): @@ -161,9 +175,9 @@ def __init__( break # Python 2.x - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.obtained # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 1f5247021..241aea86c 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -14,9 +14,21 @@ and/or the lease has been lost. """ + +from __future__ import annotations + import re import time import uuid +from typing import ( + Any, + Iterable, + Literal, + Optional, + Pattern, + Union, + TYPE_CHECKING, +) from kazoo.exceptions import ( CancelledError, @@ -31,20 +43,30 @@ RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient + class _Watch(object): - def __init__(self, duration=None): + def __init__(self, duration: Optional[float] = None): self.duration = duration - self.started_at = None + self.started_at: Optional[float] = None - def start(self): + def start(self) -> None: self.started_at = time.monotonic() - def leftover(self): + def leftover(self) -> Optional[float]: if self.duration is None: return None else: - elapsed = time.monotonic() - self.started_at + # We should probably set started_at to either 0 or + # time.monotonic() in __init__ to avoid the type ignore + # here, but this is a private class and it's pretty clear + # that start() should be called before leftover() so I'm + # not sure it's worth it. + elapsed = ( + time.monotonic() - self.started_at # type: ignore[operator] + ) return max(0, self.duration - elapsed) @@ -77,7 +99,13 @@ class Lock(object): # sequence number. Involved in read/write locks. _EXCLUDE_NAMES = ["__lock__"] - def __init__(self, client, path, identifier=None, extra_lock_patterns=()): + def __init__( + self, + client: KazooClient, + path: str, + identifier: Optional[str] = None, + extra_lock_patterns: Iterable[str] = (), + ): """Create a Kazoo lock. :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -97,10 +125,10 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): """ self.client = client self.path = path - self._exclude_names = set( + self._exclude_names: set[str] = set( self._EXCLUDE_NAMES + list(extra_lock_patterns) ) - self._contenders_re = re.compile( + self._contenders_re: Pattern[str] = re.compile( r"(?:{patterns})(-?\d{{10}})$".format( patterns="|".join(self._exclude_names) ) @@ -109,7 +137,7 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock self.data = str(identifier or "").encode("utf-8") - self.node = None + self.node: Optional[str] = None self.wake_event = client.handler.event_object() @@ -129,16 +157,21 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): ) self._acquire_method_lock = client.handler.lock_object() - def _ensure_path(self): + def _ensure_path(self) -> None: self.client.ensure_path(self.path) self.assured_path = True - def cancel(self): + def cancel(self) -> None: """Cancel a pending lock acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None, ephemeral=True): + def acquire( + self, + blocking: bool = True, + timeout: Optional[float] = None, + ephemeral: bool = True, + ) -> bool: """ Acquire the lock. By defaults blocks and waits forever. @@ -204,11 +237,16 @@ def acquire(self, blocking=True, timeout=None, ephemeral=True): finally: self._acquire_method_lock.release() - def _watch_session(self, state): + def _watch_session(self, state: Any) -> bool: self.wake_event.set() return True - def _inner_acquire(self, blocking, timeout, ephemeral=True): + def _inner_acquire( + self, + blocking: bool, + timeout: Optional[float], + ephemeral: bool = True, + ) -> bool: # wait until it's our chance to get it.. if self.is_acquired: if not blocking: @@ -219,7 +257,7 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): if not self.assured_path: self._ensure_path() - node = None + node: Optional[str] = None if self.create_tried: node = self._find_node() else: @@ -265,10 +303,10 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): finally: self.client.remove_listener(self._watch_session) - def _watch_predecessor(self, event): + def _watch_predecessor(self, event: Any) -> None: self.wake_event.set() - def _get_predecessor(self, node): + def _get_predecessor(self, node: str) -> Optional[str]: """returns `node`'s predecessor or None Note: This handle the case where the current lock is not a contender @@ -277,7 +315,7 @@ def _get_predecessor(self, node): """ node_sequence = node[len(self.prefix) :] children = self.client.get_children(self.path) - found_self = False + found_self: Union[Literal[False], None, re.Match[str]] = False # Filter out the contenders using the computed regex contender_matches = [] for child in children: @@ -308,17 +346,17 @@ def _get_predecessor(self, node): sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) return sorted_matches[-1].string - def _find_node(self): + def _find_node(self) -> Optional[str]: children = self.client.get_children(self.path) for child in children: if child.startswith(self.prefix): return child return None - def _delete_node(self, node): + def _delete_node(self, node: str) -> None: self.client.delete(self.path + "/" + node) - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: node = self.node or self._find_node() if node: @@ -326,16 +364,18 @@ def _best_effort_cleanup(self): except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lock immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: - self._delete_node(self.node) + # I don't think it's possible for self.node to be None here if + # self.is_acquired is true. + self._delete_node(self.node) # type: ignore[arg-type] except NoNodeError: # pragma: nocover pass @@ -343,7 +383,7 @@ def _inner_release(self): self.node = None return True - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders for the lock. @@ -390,10 +430,15 @@ def contenders(self): return contenders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Any, + exc_value: Any, + traceback: Any, + ) -> None: self.release() @@ -492,7 +537,13 @@ class Semaphore(object): """ - def __init__(self, client, path, identifier=None, max_leases=1): + def __init__( + self, + client: KazooClient, + path: str, + identifier: Optional[str] = None, + max_leases: int = 1, + ): """Create a Kazoo Lock :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -528,7 +579,7 @@ def __init__(self, client, path, identifier=None, max_leases=1): self.cancelled = False self._session_expired = False - def _ensure_path(self): + def _ensure_path(self) -> None: result = self.client.ensure_path(self.path) self.assured_path = True if result is True: @@ -549,12 +600,16 @@ def _ensure_path(self): else: self.client.set(self.path, str(self.max_leases).encode("utf-8")) - def cancel(self): + def cancel(self) -> None: """Cancel a pending semaphore acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None): + def acquire( + self, + blocking: bool = True, + timeout: Optional[float] = None, + ) -> bool: """Acquire the semaphore. By defaults blocks and waits forever. :param blocking: Block until semaphore is obtained or @@ -592,7 +647,11 @@ def acquire(self, blocking=True, timeout=None): return self.is_acquired - def _inner_acquire(self, blocking, timeout=None): + def _inner_acquire( + self, + blocking: bool, + timeout: Optional[float] = None, + ) -> bool: """Inner loop that runs from the top anytime a command hits a retryable Zookeeper exception.""" self._session_expired = False @@ -607,7 +666,12 @@ def _inner_acquire(self, blocking, timeout=None): w = _Watch(duration=timeout) w.start() - lock = self.client.Lock(self.lock_path, self.data) + # This is passing bytes data, but self.client.Lock expects a str, + # which I think is a bug in this code. However, I don't want to + # change any code at this point, so we just ignore the type error here. + lock = self.client.Lock( + self.lock_path, self.data # type: ignore[arg-type] + ) try: gotten = lock.acquire(blocking=blocking, timeout=w.leftover()) if not gotten: @@ -633,10 +697,10 @@ def _inner_acquire(self, blocking, timeout=None): finally: lock.release() - def _watch_lease_change(self, event): + def _watch_lease_change(self, event: Any) -> None: self.wake_event.set() - def _get_lease(self, data=None): + def _get_lease(self, data: Any = None) -> bool: # Make sure the session is still valid if self._session_expired: raise ForceRetryError("Retry on session loss at top") @@ -665,25 +729,26 @@ def _get_lease(self, data=None): # Return current state return self.is_acquired - def _watch_session(self, state): + def _watch_session(self, state: Any) -> Optional[bool]: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() # Return true to de-register return True + return None - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.delete(self.create_path) except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lease immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: @@ -693,7 +758,7 @@ def _inner_release(self): self.is_acquired = False return True - def lease_holders(self): + def lease_holders(self) -> list[str]: """Return an unordered list of the current lease holders. .. note:: @@ -716,8 +781,13 @@ def lease_holders(self): pass return lease_holders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Any, + exc_value: Any, + traceback: Any, + ) -> None: self.release() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 21dc6ef4a..8d916fd4e 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -17,20 +17,30 @@ so that no two workers own the same queue. """ + +from __future__ import annotations + from functools import partial import logging import os import socket +from enum import Enum +from typing import Any, Callable, Iterator, Optional, Sequence, TYPE_CHECKING from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState from kazoo.recipe.watchers import PatientChildrenWatch +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import IAsyncResult + from kazoo.recipe.lock import Lock log = logging.getLogger(__name__) -class PartitionState(object): +# This is a (str, Enum) for backwards compatibility. +class PartitionState(str, Enum): """High level partition state values .. attribute:: ALLOCATING @@ -139,14 +149,16 @@ class SetPartitioner(object): def __init__( self, - client, - path, - set, - partition_func=None, - identifier=None, - time_boundary=30, - max_reaction_time=1, - state_change_event=None, + client: KazooClient, + path: str, + set: Sequence[str], + partition_func: Optional[ + Callable[[str, list[str], Sequence[str]], list[str]] + ] = None, + identifier: Optional[str] = None, + time_boundary: float = 30, + max_reaction_time: float = 1, + state_change_event: Optional[Any] = None, ): """Create a :class:`~SetPartitioner` instance @@ -176,13 +188,13 @@ def __init__( self._client = client self._path = path self._set = set - self._partition_set = [] + self._partition_set: list[str] = [] self._partition_func = partition_func or self._partitioner self._identifier = identifier or "%s-%s" % ( socket.getfqdn(), os.getpid(), ) - self._locks = [] + self._locks: list[Lock] = [] self._lock_path = "/".join([path, "locks"]) self._party_path = "/".join([path, "party"]) self._time_boundary = time_boundary @@ -208,33 +220,33 @@ def __init__( # so we know when we're ready self._child_watching(self._allocate_transition, client_handler=True) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Return the partitions in this partition set""" for partition in self._partition_set: yield partition @property - def failed(self): + def failed(self) -> bool: """Corresponds to the :attr:`PartitionState.FAILURE` state""" return self.state == PartitionState.FAILURE @property - def release(self): + def release(self) -> bool: """Corresponds to the :attr:`PartitionState.RELEASE` state""" return self.state == PartitionState.RELEASE @property - def allocating(self): + def allocating(self) -> bool: """Corresponds to the :attr:`PartitionState.ALLOCATING` state""" return self.state == PartitionState.ALLOCATING @property - def acquired(self): + def acquired(self) -> bool: """Corresponds to the :attr:`PartitionState.ACQUIRED` state""" return self.state == PartitionState.ACQUIRED - def wait_for_acquire(self, timeout=30): + def wait_for_acquire(self, timeout: float = 30) -> None: """Wait for the set to be partitioned and acquired :param timeout: How long to wait before returning. @@ -243,7 +255,7 @@ def wait_for_acquire(self, timeout=30): """ self._acquire_event.wait(timeout) - def release_set(self): + def release_set(self) -> None: """Call to release the set This method begins the step of allocating once the set has @@ -263,12 +275,12 @@ def release_set(self): self._set_state(PartitionState.ALLOCATING) self._child_watching(self._allocate_transition, client_handler=True) - def finish(self): + def finish(self) -> None: """Call to release the set and leave the party""" self._release_locks() self._fail_out() - def _fail_out(self): + def _fail_out(self) -> None: with self._state_change: self._set_state(PartitionState.FAILURE) if self._party.participating: @@ -277,7 +289,7 @@ def _fail_out(self): except KazooException: # pragma: nocover pass - def _allocate_transition(self, result): + def _allocate_transition(self, result: Any) -> None: """Called when in allocating mode, and the children settled""" # Did we get an exception waiting for children to settle? @@ -288,7 +300,7 @@ def _allocate_transition(self, result): children, async_result = result.get() children_changed = self._client.handler.event_object() - def updated(result): + def updated(result: IAsyncResult) -> None: with self._state_change: children_changed.set() if self.acquired: @@ -307,7 +319,7 @@ def updated(result): # Check whether the state has changed during the lock acquisition # and abort the process if so. - def abort_if_needed(): + def abort_if_needed() -> bool: if self.state_id == state_id: if children_changed.is_set(): # The party has changed. Repartitioning... @@ -365,7 +377,7 @@ def abort_if_needed(): # This mustn't happen. Means a logical error. self._fail_out() - def _release_locks(self): + def _release_locks(self) -> None: """Attempt to completely remove all the locks""" self._acquire_event.clear() for lock in self._locks[:]: @@ -378,7 +390,7 @@ def _release_locks(self): else: self._locks.remove(lock) - def _abort_lock_acquisition(self): + def _abort_lock_acquisition(self) -> None: """Called during lock acquisition if a party change occurs""" self._release_locks() @@ -391,7 +403,11 @@ def _abort_lock_acquisition(self): self._child_watching(self._allocate_transition, client_handler=True) - def _child_watching(self, func=None, client_handler=False): + def _child_watching( + self, + func: Optional[Callable[..., Any]] = None, + client_handler: bool = False, + ) -> Any: """Called when children are being watched to stabilize This actually returns immediately, child watcher spins up a @@ -414,7 +430,7 @@ def _child_watching(self, func=None, client_handler=False): asy.rawlink(func) return asy - def _establish_sessionwatch(self, state): + def _establish_sessionwatch(self, state: Any) -> bool: """Register ourself to listen for session events, we shut down if we become lost""" with self._state_change: @@ -427,7 +443,12 @@ def _establish_sessionwatch(self, state): return state == KazooState.LOST - def _partitioner(self, identifier, members, partitions): + def _partitioner( + self, + identifier: str, + members: list[str], + partitions: Sequence[str], + ) -> list[str]: # Ensure consistent order of partitions/members all_partitions = sorted(partitions) workers = sorted(members) @@ -437,7 +458,7 @@ def _partitioner(self, identifier, members, partitions): # skipping the other workers return all_partitions[i :: len(workers)] - def _set_state(self, state): + def _set_state(self, state: PartitionState) -> None: self.state = state self.state_id += 1 self.state_change_event.set() diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index 2a0f5dfb6..baf517f27 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -7,15 +7,27 @@ used for determining members of a party. """ + +from __future__ import annotations + import uuid +from typing import Any, Iterator, Optional, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseParty(object): """Base implementation of a party.""" - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: Optional[str] = None, + ): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The party path to use. @@ -29,44 +41,51 @@ def __init__(self, client, path, identifier=None): self.ensured_path = False self.participating = False - def _ensure_parent(self): + def _ensure_parent(self) -> None: if not self.ensured_path: # make sure our parent node exists self.client.ensure_path(self.path) self.ensured_path = True - def join(self): + def join(self) -> Any: """Join the party""" return self.client.retry(self._inner_join) - def _inner_join(self): + def _inner_join(self) -> None: self._ensure_parent() try: - self.client.create(self.create_path, self.data, ephemeral=True) + # This and the #type: ignore[attr-defined] below could be removed + # by setting up create_path in the constructor but trying to avoid + # changing the code too much + self.client.create( + self.create_path, # type: ignore[attr-defined] + self.data, + ephemeral=True, + ) self.participating = True except NodeExistsError: # node was already created, perhaps we are recovering from a # suspended connection self.participating = True - def leave(self): + def leave(self) -> Any: """Leave the party""" self.participating = False return self.client.retry(self._inner_leave) - def _inner_leave(self): + def _inner_leave(self) -> bool: try: - self.client.delete(self.create_path) + self.client.delete(self.create_path) # type: ignore[attr-defined] except NoNodeError: return False return True - def __len__(self): + def __len__(self) -> int: """Return a count of participating clients""" self._ensure_parent() return len(self._get_children()) - def _get_children(self): + def _get_children(self) -> list[str]: return self.client.retry(self.client.get_children, self.path) @@ -75,12 +94,14 @@ class Party(BaseParty): _NODE_NAME = "__party__" - def __init__(self, client, path, identifier=None): + def __init__( + self, client: KazooClient, path: str, identifier: Optional[str] = None + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = uuid.uuid4().hex + self._NODE_NAME self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' data values""" self._ensure_parent() children = self._get_children() @@ -93,7 +114,7 @@ def __iter__(self): except NoNodeError: # pragma: nocover pass - def _get_children(self): + def _get_children(self) -> list[str]: children = BaseParty._get_children(self) return [c for c in children if self._NODE_NAME in c] @@ -109,12 +130,17 @@ class ShallowParty(BaseParty): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: Optional[str] = None, + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = "-".join([uuid.uuid4().hex, self.data.decode("utf-8")]) self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' identifiers""" self._ensure_parent() children = self._get_children() diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index 30d3066e4..c79d2dd71 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -9,17 +9,24 @@ See: https://github.com/python-zk/kazoo/issues/175 """ + +from __future__ import annotations + import uuid +from typing import Any, Optional, TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError from kazoo.protocol.states import EventType from kazoo.retry import ForceRetryError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseQueue(object): """A common base class for queue implementations.""" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. @@ -27,10 +34,10 @@ def __init__(self, client, path): self.client = client self.path = path self._entries_path = path - self.structure_paths = (self.path,) + self.structure_paths: tuple[str, ...] = (self.path,) self.ensured_path = False - def _check_put_arguments(self, value, priority=100): + def _check_put_arguments(self, value: bytes, priority: int = 100) -> None: if not isinstance(value, bytes): raise TypeError("value must be a byte string") if not isinstance(priority, int): @@ -38,14 +45,14 @@ def _check_put_arguments(self, value, priority=100): elif priority < 0 or priority > 999: raise ValueError("priority must be between 0 and 999") - def _ensure_paths(self): + def _ensure_paths(self) -> None: if not self.ensured_path: # make sure our parent / internal structure nodes exists for path in self.structure_paths: self.client.ensure_path(path) self.ensured_path = True - def __len__(self): + def __len__(self) -> int: self._ensure_paths() _, stat = self.client.retry(self.client.get, self._entries_path) return stat.children_count @@ -62,19 +69,19 @@ class Queue(BaseQueue): prefix = "entry-" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(Queue, self).__init__(client, path) - self._children = [] + self._children: list[str] = [] - def __len__(self): + def __len__(self) -> int: """Return queue size.""" return super(Queue, self).__len__() - def get(self): + def get(self) -> Optional[bytes]: """ Get item data and remove an item from the queue. @@ -84,7 +91,7 @@ def get(self): self._ensure_paths() return self.client.retry(self._inner_get) - def _inner_get(self): + def _inner_get(self) -> Optional[bytes]: if not self._children: self._children = self.client.retry( self.client.get_children, self.path @@ -105,7 +112,7 @@ def _inner_get(self): self._children.pop(0) return data - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an item into the queue. :param value: Byte string to put into the queue. @@ -150,26 +157,26 @@ class LockingQueue(BaseQueue): entries = "/entries" entry = "entry" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(LockingQueue, self).__init__(client, path) self.id = uuid.uuid4().hex.encode() - self.processing_element = None + self.processing_element: Optional[tuple[str, bytes]] = None self._lock_path = self.path + self.lock self._entries_path = self.path + self.entries self.structure_paths = (self._lock_path, self._entries_path) - def __len__(self): + def __len__(self) -> int: """Returns the current length of the queue. :returns: queue size (includes locked entries count). """ return super(LockingQueue, self).__len__() - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an entry into the queue. :param value: Byte string to put into the queue. @@ -189,7 +196,7 @@ def put(self, value, priority=100): sequence=True, ) - def put_all(self, values, priority=100): + def put_all(self, values: list[bytes], priority: int = 100) -> None: """Put several entries into the queue. The action only succeeds if all entries where put into the queue. @@ -221,7 +228,7 @@ def put_all(self, values, priority=100): sequence=True, ) - def get(self, timeout=None): + def get(self, timeout: Optional[float] = None) -> Optional[bytes]: """Locks and gets an entry from the queue. If a previously got entry was not consumed, this method will return that entry. @@ -237,7 +244,7 @@ def get(self, timeout=None): else: return self._inner_get(timeout) - def holds_lock(self): + def holds_lock(self) -> bool: """Checks if a node still holds the lock. :returns: True if a node still holds the lock, False otherwise. @@ -251,7 +258,7 @@ def holds_lock(self): value, stat = self.client.retry(self.client.get, lock_path) return value == self.id - def consume(self): + def consume(self) -> bool: """Removes a currently processing entry from the queue. :returns: True if element was removed successfully, False otherwise. @@ -271,7 +278,7 @@ def consume(self): else: return False - def release(self): + def release(self) -> bool: """Removes the lock from currently processed item without consuming it. :returns: True if the lock was removed successfully, False otherwise. @@ -289,13 +296,13 @@ def release(self): else: return False - def _inner_get(self, timeout): + def _inner_get(self, timeout: Optional[float]) -> Optional[bytes]: flag = self.client.handler.event_object() lock = self.client.handler.lock_object() canceled = False value = [] - def check_for_updates(event): + def check_for_updates(event: Optional[Any]) -> None: if event is not None and event.type != EventType.CHILD: return with lock: @@ -330,8 +337,8 @@ def check_for_updates(event): retVal = value[0][1] return retVal - def _filter_locked(self, values, taken): - taken = set(taken) + def _filter_locked(self, values: list[str], taken: list[str]) -> list[str]: + taken = set(taken) # type: ignore[assignment] available = sorted(values) return ( available @@ -339,7 +346,7 @@ def _filter_locked(self, values, taken): else [x for x in available if x not in taken] ) - def _take(self, id_): + def _take(self, id_: str) -> Optional[tuple[str, bytes]]: try: self.client.create( "{path}/{id}".format(path=self._lock_path, id=id_), diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index d4cb0300e..d45e01acb 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -10,15 +10,22 @@ will result in an exception being thrown. """ + +from __future__ import annotations + from functools import partial, wraps import logging import time import warnings +from typing import Any, List, Callable, Optional, Union, TYPE_CHECKING from kazoo.exceptions import ConnectionClosedError, NoNodeError, KazooException -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat from kazoo.retry import KazooRetry +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import IAsyncResult log = logging.getLogger(__name__) @@ -26,9 +33,9 @@ _STOP_WATCHING = object() -def _ignore_closed(func): +def _ignore_closed(func: Callable[..., None]) -> Callable[..., None]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> None: try: return func(*args, **kwargs) except ConnectionClosedError: @@ -37,6 +44,15 @@ def wrapper(*args, **kwargs): return wrapper +DataWatchFunc = Union[ + Callable[[Optional[str], Optional[ZnodeStat]], Optional[bool]], + Callable[ + [Optional[str], Optional[ZnodeStat], Optional[WatchedEvent]], + Optional[bool], + ], +] + + class DataWatch(object): """Watches a node for data updates and calls the specified function each time it changes @@ -88,7 +104,14 @@ def my_func(data, stat, event): """ - def __init__(self, client, path, func=None, *args, **kwargs): + def __init__( + self, + client: KazooClient, + path: str, + func: Optional[DataWatchFunc] = None, + *args: Any, + **kwargs: Any, + ): """Create a data watcher for a path :param client: A zookeeper client. @@ -107,7 +130,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._func = func self._stopped = False self._run_lock = client.handler.lock_object() - self._version = None + self._version: Optional[int] = None self._retry = KazooRetry( max_tries=None, sleep_func=client.handler.sleep_func ) @@ -132,7 +155,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._client.add_listener(self._session_watcher) self._get_data() - def __call__(self, func): + def __call__(self, func: DataWatchFunc) -> DataWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -155,16 +178,27 @@ def __call__(self, func): self._get_data() return func - def _log_func_exception(self, data, stat, event=None): + def _log_func_exception( + self, + data: Any, + stat: Optional[ZnodeStat], + event: Optional[WatchedEvent] = None, + ) -> None: try: # For backwards compatibility, don't send event to the # callback unless the send_event is set in constructor if not self._ever_called: self._ever_called = True try: - result = self._func(data, stat, event) + # The type ignores here are because mypy can't figure out that + # 1) self._func can't ever be None (fingers crossed) + # 2) the function can be called with 2 arguments or with 3 + # arguments + result = self._func( # type: ignore[call-arg, misc] + data, stat, event + ) except TypeError: - result = self._func(data, stat) + result = self._func(data, stat) # type: ignore[call-arg, misc] if result is False: self._stopped = True self._func = None @@ -174,7 +208,7 @@ def _log_func_exception(self, data, stat, event=None): raise @_ignore_closed - def _get_data(self, event=None): + def _get_data(self, event: Optional[WatchedEvent] = None) -> None: # Ensure this runs one at a time, possible because the session # watcher may trigger a run with self._run_lock: @@ -183,6 +217,7 @@ def _get_data(self, event=None): initial_version = self._version + stat: Optional[ZnodeStat] try: data, stat = self._retry( self._client.get, self._path, self._watcher @@ -210,18 +245,24 @@ def _get_data(self, event=None): if initial_version != self._version or not self._ever_called: self._log_func_exception(data, stat, event) - def _watcher(self, event): + def _watcher(self, event: KazooState) -> None: self._get_data(event=event) - def _set_watch(self, state): + def _set_watch(self, state: KazooState) -> None: with self._run_lock: self._watch_established = state - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.CONNECTED: self._client.handler.spawn(self._get_data) +ChildrenWatchFunc = Union[ + Callable[[List[str]], Optional[bool]], + Callable[[List[str], Optional[WatchedEvent]], Optional[bool]], +] + + class ChildrenWatch(object): """Watches a node for children updates and calls the specified function each time it changes @@ -253,11 +294,11 @@ def my_func(children): def __init__( self, - client, - path, - func=None, - allow_session_lost=True, - send_event=False, + client: KazooClient, + path: str, + func: Optional[ChildrenWatchFunc] = None, + allow_session_lost: bool = True, + send_event: bool = False, ): """Create a children watcher for a path @@ -290,7 +331,7 @@ def __init__( self._watch_established = False self._allow_session_lost = allow_session_lost self._run_lock = client.handler.lock_object() - self._prior_children = None + self._prior_children: Optional[List[str]] = None self._used = False # Register our session listener if we're going to resume @@ -301,7 +342,7 @@ def __init__( self._client.add_listener(self._session_watcher) self._get_children() - def __call__(self, func): + def __call__(self, func: ChildrenWatchFunc) -> ChildrenWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -325,7 +366,7 @@ def __call__(self, func): return func @_ignore_closed - def _get_children(self, event=None): + def _get_children(self, event: Optional[WatchedEvent] = None) -> None: with self._run_lock: # Ensure this runs one at a time if self._stopped: return @@ -351,9 +392,16 @@ def _get_children(self, event=None): try: if self._send_event: - result = self._func(children, event) + # See comment about the type ignore here in DataWatch, + # it's the same issue where mypy can't figure out that the + # function can be called with 1 argument or with 2 + result = self._func( # type: ignore[misc] + children, event # type: ignore[call-arg] + ) else: - result = self._func(children) + result = self._func( # type: ignore[misc] + children # type: ignore[call-arg] + ) if result is False: self._stopped = True self._func = None @@ -363,11 +411,11 @@ def _get_children(self, event=None): log.exception(exc) raise - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: if event.type != "NONE": self._get_children(event) - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state in (KazooState.LOST, KazooState.SUSPENDED): self._watch_established = False elif ( @@ -408,14 +456,16 @@ class PatientChildrenWatch(object): """ - def __init__(self, client, path, time_boundary=30): + def __init__( + self, client: KazooClient, path: str, time_boundary: float = 30 + ): self.client = client self.path = path - self.children = [] + self.children: list[str] = [] self.time_boundary = time_boundary self.children_changed = client.handler.event_object() - def start(self): + def start(self) -> IAsyncResult: """Begin the watching process asynchronously :returns: An :class:`~kazoo.interfaces.IAsyncResult` instance @@ -427,7 +477,7 @@ def start(self): self.client.handler.spawn(self._inner_start) return asy - def _inner_start(self): + def _inner_start(self) -> None: try: while True: async_result = self.client.handler.async_result() @@ -447,6 +497,8 @@ def _inner_start(self): except Exception as exc: self.asy.set_exception(exc) - def _children_watcher(self, async_result, event): + def _children_watcher( + self, async_result: IAsyncResult, event: WatchedEvent + ) -> None: self.children_changed.set() async_result.set(time.monotonic()) diff --git a/kazoo/retry.py b/kazoo/retry.py index fb9e8fc7b..c8e0bcc2b 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import random import time +from typing import Any, Callable, Optional, TypeVar from kazoo.exceptions import ( ConnectionClosedError, @@ -10,7 +13,6 @@ SessionExpiredError, ) - log = logging.getLogger(__name__) @@ -37,18 +39,20 @@ class KazooRetry(object): EXPIRED_EXCEPTIONS = (SessionExpiredError,) + RETRY_RETURN = TypeVar("RETRY_RETURN") + def __init__( self, - max_tries=1, - delay=0.1, - backoff=2, - max_jitter=0.4, - max_delay=60.0, - ignore_expire=True, - sleep_func=time.sleep, - deadline=None, - interrupt=None, - ): + max_tries: Optional[int] = 1, + delay: float = 0.1, + backoff: int = 2, + max_jitter: float = 0.4, + max_delay: float = 60.0, + ignore_expire: bool = True, + sleep_func: Callable[[float], None] = time.sleep, + deadline: Optional[float] = None, + interrupt: Optional[Callable[[], bool]] = None, + ) -> None: """Create a :class:`KazooRetry` instance for retrying function calls. @@ -81,20 +85,22 @@ def __init__( self._attempts = 0 self._cur_delay = delay self.deadline = deadline - self._cur_stoptime = None + self._cur_stoptime: Optional[float] = None self.sleep_func = sleep_func - self.retry_exceptions = self.RETRY_EXCEPTIONS + self.retry_exceptions: tuple[ + type[Exception], ... + ] = self.RETRY_EXCEPTIONS self.interrupt = interrupt if ignore_expire: self.retry_exceptions += self.EXPIRED_EXCEPTIONS - def reset(self): + def reset(self) -> None: """Reset the attempt counter""" self._attempts = 0 self._cur_delay = self.delay self._cur_stoptime = None - def copy(self): + def copy(self) -> KazooRetry: """Return a clone of this retry manager""" obj = KazooRetry( max_tries=self.max_tries, @@ -109,7 +115,9 @@ def copy(self): obj.retry_exceptions = self.retry_exceptions return obj - def __call__(self, func, *args, **kwargs): + def __call__( + self, func: Callable[..., RETRY_RETURN], *args: Any, **kwargs: Any + ) -> RETRY_RETURN: """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/kazoo/security.py b/kazoo/security.py index 683994451..1b383795f 100644 --- a/kazoo/security.py +++ b/kazoo/security.py @@ -1,14 +1,19 @@ """Kazoo Security""" + +from __future__ import annotations + from base64 import b64encode -from collections import namedtuple import hashlib +from typing import NamedTuple # Represents a Zookeeper ID and ACL object -Id = namedtuple("Id", "scheme id") +class Id(NamedTuple): + scheme: str + id: str -class ACL(namedtuple("ACL", "perms id")): +class ACL(NamedTuple): """An ACL for a Zookeeper Node An ACL object is created by using an :class:`Id` object along with @@ -17,8 +22,11 @@ class ACL(namedtuple("ACL", "perms id")): the desired scheme, id, and permissions. """ + perms: int + id: Id + @property - def acl_list(self): + def acl_list(self) -> list[str]: perms = [] if self.perms & Permissions.ALL == Permissions.ALL: perms.append("ALL") @@ -35,7 +43,7 @@ def acl_list(self): perms.append("ADMIN") return perms - def __repr__(self): + def __repr__(self) -> str: return "ACL(perms=%r, acl_list=%s, id=%r)" % ( self.perms, self.acl_list, @@ -62,7 +70,7 @@ class Permissions(object): READ_ACL_UNSAFE = [ACL(Permissions.READ, ANYONE_ID_UNSAFE)] -def make_digest_acl_credential(username, password): +def make_digest_acl_credential(username: str, password: str) -> str: """Create a SHA1 digest credential. .. note:: @@ -80,15 +88,15 @@ def make_digest_acl_credential(username, password): def make_acl( - scheme, - credential, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + scheme: str, + credential: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Given a scheme and credential, return an :class:`ACL` object appropriate for use with Kazoo. @@ -131,15 +139,15 @@ def make_acl( def make_digest_acl( - username, - password, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + username: str, + password: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Create a digest ACL for Zookeeper with the given permissions This method combines :meth:`make_digest_acl_credential` and diff --git a/pyproject.toml b/pyproject.toml index db3890c58..5af3c0d4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,11 @@ timeout = 180 ignore_missing_imports = false # Disallow dynamic typing -disallow_any_unimported = true +disallow_any_unimported = false # true disallow_any_expr = false -disallow_any_decorated = true -disallow_any_explicit = true -disallow_any_generics = true +disallow_any_decorated = false # true +disallow_any_explicit = false # true +disallow_any_generics = false # true disallow_subclassing_any = true # Untyped definitions and calls @@ -90,33 +90,20 @@ verbosity = 0 # FIXME: As type annotations are introduced, please remove the appropriate # ignore_errors flag below. New modules should NOT be added here! +# no-any-return - We still have some imported modules with no type annotations, +# and I want to avoid code changes as much as possible. + +# unused-ignore This is a temporary workaround for the fact that mypy can +# produce different errors in 3.8 and 3.14, and I want to avoid code changes +# as much as possible. + +disable_error_code = [ + 'no-any-return', + 'unused-ignore', +] + [[tool.mypy.overrides]] module = [ - 'kazoo.client', - 'kazoo.exceptions', - 'kazoo.handlers.eventlet', - 'kazoo.handlers.gevent', - 'kazoo.handlers.threading', - 'kazoo.handlers.utils', - 'kazoo.hosts', - 'kazoo.interfaces', - 'kazoo.loggingsupport', - 'kazoo.protocol.connection', - 'kazoo.protocol.paths', - 'kazoo.protocol.serialization', - 'kazoo.protocol.states', - 'kazoo.recipe.barrier', - 'kazoo.recipe.cache', - 'kazoo.recipe.counter', - 'kazoo.recipe.election', - 'kazoo.recipe.lease', - 'kazoo.recipe.lock', - 'kazoo.recipe.partitioner', - 'kazoo.recipe.party', - 'kazoo.recipe.queue', - 'kazoo.recipe.watchers', - 'kazoo.retry', - 'kazoo.security', 'kazoo.testing.common', 'kazoo.testing.harness', 'kazoo.tests.conftest', @@ -146,6 +133,20 @@ module = [ 'kazoo.tests.test_utils', 'kazoo.tests.test_watchers', 'kazoo.tests.util', - 'kazoo.version' ] ignore_errors = true + +[[tool.mypy.overrides]] + module = ["eventlet.*"] + ignore_missing_imports = true + #follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["gevent.thread"] + ignore_missing_imports = true + #follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["puresasl.*"] + ignore_missing_imports = true + #follow_untyped_imports = true diff --git a/setup.cfg b/setup.cfg index 2e771f176..89e8e8442 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,6 +81,7 @@ docs = typing = mypy>=0.991 + types-gevent alldeps = %(dev)s From 1ed573bba5082166e385a22132c67d60ae00f14a Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Wed, 29 Apr 2026 22:58:57 +0100 Subject: [PATCH 2/8] Replace Optional and Union wherever possible --- kazoo/client.py | 99 ++++++++++++++++----------------- kazoo/handlers/utils.py | 30 +++++----- kazoo/hosts.py | 5 +- kazoo/interfaces.py | 10 ++-- kazoo/protocol/connection.py | 46 +++++++-------- kazoo/protocol/serialization.py | 46 +++++++-------- kazoo/protocol/states.py | 6 +- kazoo/recipe/barrier.py | 6 +- kazoo/recipe/cache.py | 18 +++--- kazoo/recipe/counter.py | 6 +- kazoo/recipe/election.py | 4 +- kazoo/recipe/lease.py | 6 +- kazoo/recipe/lock.py | 32 +++++------ kazoo/recipe/partitioner.py | 13 ++--- kazoo/recipe/party.py | 11 ++-- kazoo/recipe/queue.py | 16 +++--- kazoo/recipe/watchers.py | 18 +++--- kazoo/retry.py | 10 ++-- 18 files changed, 188 insertions(+), 194 deletions(-) diff --git a/kazoo/client.py b/kazoo/client.py index ea9f5856e..1e195b249 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -16,7 +16,6 @@ Optional, Sequence, Set, - Union, overload, TYPE_CHECKING, ) @@ -126,22 +125,22 @@ class KazooClient(object): def __init__( self, - hosts: Union[str, list[str]] = "127.0.0.1:2181", + hosts: str | list[str] = "127.0.0.1:2181", timeout: float = 10.0, - client_id: Optional[tuple] = None, - handler: Optional[IHandler] = None, - default_acl: Optional[Sequence[ACL]] = None, - auth_data: Optional[set] = None, - sasl_options: Optional[dict] = None, - read_only: Optional[bool] = None, + client_id: tuple | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: set | None = None, + sasl_options: dict | None = None, + read_only: bool | None = None, randomize_hosts: bool = True, - connection_retry: Optional[Union[KazooRetry, dict]] = None, - command_retry: Optional[Union[KazooRetry, dict]] = None, - logger: Optional[logging.Logger] = None, - keyfile: Optional[str] = None, - keyfile_password: Optional[str] = None, - certfile: Optional[str] = None, - ca: Optional[str] = None, + connection_retry: KazooRetry | dict | None = None, + command_retry: KazooRetry | dict | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, use_ssl: bool = False, verify_certs: bool = True, check_hostname: bool = False, @@ -498,7 +497,7 @@ def client_state(self) -> KeeperState: return self._state @property - def client_id(self) -> Optional[tuple]: + def client_id(self) -> tuple | None: """Returns the client id for this Zookeeper session if connected. @@ -518,8 +517,8 @@ def connected(self) -> bool: def set_hosts( self, - hosts: Union[str, list[str]], - randomize_hosts: Optional[bool] = None, + hosts: str | list[str], + randomize_hosts: bool | None = None, ) -> None: """sets the list of hosts used by this client. @@ -686,7 +685,7 @@ def _safe_close(self) -> None: def _call( self, request: object, async_object: IAsyncResult - ) -> Optional[bool]: + ) -> bool | None: """Ensure the client is in CONNECTED or SUSPENDED state and put the request in the queue if it is. @@ -874,7 +873,7 @@ def server_version(self, retries: int = 3) -> tuple: """ - def _try_fetch() -> Optional[tuple[int, ...]]: + def _try_fetch() -> tuple[int, ...] | None: data = self.command(b"envi") data_parsed = {} for line in data.splitlines(): @@ -900,7 +899,7 @@ def _try_fetch() -> Optional[tuple[int, ...]]: except ValueError: return None - def _is_valid(version: Optional[tuple[int, ...]]) -> bool: + def _is_valid(version: tuple[int, ...] | None) -> bool: # All zookeeper versions should have at least major.minor # version numbers; if we get one that doesn't it is likely not # correct and was truncated... @@ -1017,7 +1016,7 @@ def create( self, path: str, value: bytes = b"", - acl: Optional[Sequence[ACL]] = None, + acl: Sequence[ACL] | None = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, @@ -1030,7 +1029,7 @@ def create( self, path: str, value: bytes = b"", - acl: Optional[Sequence[ACL]] = None, + acl: Sequence[ACL] | None = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, @@ -1042,12 +1041,12 @@ def create( self, path: str, value: bytes = b"", - acl: Optional[Sequence[ACL]] = None, + acl: Sequence[ACL] | None = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False, - ) -> Union[str, tuple[str, ZnodeStat]]: + ) -> str | tuple[str, ZnodeStat]: """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1140,7 +1139,7 @@ def create_async( self, path: str, value: bytes = b"", - acl: Optional[Sequence[ACL]] = None, + acl: Sequence[ACL] | None = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, @@ -1213,7 +1212,7 @@ def retry_completion(result: IAsyncResult) -> None: @wrap(async_result) def create_completion( result: IAsyncResult, - ) -> Optional[Union[str, tuple[str, ZnodeStat]]]: + ) -> str | tuple[str, ZnodeStat] | None: try: if include_data: new_path, stat = result.get() @@ -1266,9 +1265,7 @@ def _create_async_inner( raise async_result.exception # type: ignore[misc] return async_result - def ensure_path( - self, path: str, acl: Optional[Sequence[ACL]] = None - ) -> bool: + def ensure_path(self, path: str, acl: Sequence[ACL] | None = None) -> bool: """Recursively create a path if it doesn't exist. :param path: Path of node. @@ -1278,7 +1275,7 @@ def ensure_path( return self.ensure_path_async(path, acl).get() def ensure_path_async( - self, path: str, acl: Optional[Sequence[ACL]] = None + self, path: str, acl: Sequence[ACL] | None = None ) -> IAsyncResult: """Recursively create a path asynchronously if it doesn't exist. Takes the same arguments as :meth:`ensure_path`. @@ -1306,7 +1303,7 @@ def prepare_completion(next_path: str, result: IAsyncResult) -> None: @wrap(async_result) def exists_completion( path: str, result: IAsyncResult - ) -> Optional[Literal[True]]: + ) -> Literal[True] | None: if result.get(): return True parent, node = split(path) @@ -1323,8 +1320,8 @@ def exists_completion( return async_result def exists( - self, path: str, watch: Optional[WatchFunc] = None - ) -> Optional[ZnodeStat]: + self, path: str, watch: WatchFunc | None = None + ) -> ZnodeStat | None: """Check if a node exists. If a watch is provided, it will be left on the node with the @@ -1347,7 +1344,7 @@ def exists( return self.exists_async(path, watch=watch).get() def exists_async( - self, path: str, watch: Optional[WatchFunc] = None + self, path: str, watch: WatchFunc | None = None ) -> IAsyncResult: """Asynchronously check if a node exists. Takes the same arguments as :meth:`exists`. @@ -1367,7 +1364,7 @@ def exists_async( return async_result def get( - self, path: str, watch: Optional[WatchFunc] = None + self, path: str, watch: WatchFunc | None = None ) -> tuple[bytes, ZnodeStat]: """Get the value of a node. @@ -1394,7 +1391,7 @@ def get( return self.get_async(path, watch=watch).get() def get_async( - self, path: str, watch: Optional[WatchFunc] = None + self, path: str, watch: WatchFunc | None = None ) -> IAsyncResult: """Asynchronously get the value of a node. Takes the same arguments as :meth:`get`. @@ -1416,7 +1413,7 @@ def get_async( def get_children( self, path: str, - watch: Optional[WatchFunc] = None, + watch: WatchFunc | None = None, include_data: bool = False, ) -> list[str]: """Get a list of child nodes of a path. @@ -1459,7 +1456,7 @@ def get_children( def get_children_async( self, path: str, - watch: Optional[WatchFunc] = None, + watch: WatchFunc | None = None, include_data: bool = False, ) -> IAsyncResult: """Asynchronously get a list of child nodes of a path. Takes @@ -1476,7 +1473,7 @@ def get_children_async( raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() - req: Union[GetChildren, GetChildren2] + req: GetChildren | GetChildren2 if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1575,7 +1572,7 @@ def set_acls_async( return async_result def set( - self, path: str, value: Optional[bytes], version: int = -1 + self, path: str, value: bytes | None, version: int = -1 ) -> ZnodeStat: """Set the value of a node. @@ -1612,7 +1609,7 @@ def set( return self.set_async(path, value, version).get() def set_async( - self, path: str, value: Optional[bytes], version: int = -1 + self, path: str, value: bytes | None, version: int = -1 ) -> IAsyncResult: """Set the value of a node. Takes the same arguments as :meth:`set`. @@ -1711,7 +1708,7 @@ def delete_async(self, path: str, version: int = -1) -> IAsyncResult: ) return async_result - def _delete_recursive(self, path: str) -> Optional[Literal[True]]: + def _delete_recursive(self, path: str) -> Literal[True] | None: try: children = self.get_children(path) except NoNodeError: @@ -1733,9 +1730,9 @@ def _delete_recursive(self, path: str) -> Optional[Literal[True]]: def reconfig( self, - joining: Optional[str], - leaving: Optional[str], - new_members: Optional[str], + joining: str | None, + leaving: str | None, + new_members: str | None, from_config: int = -1, ) -> tuple[bytes, ZnodeStat]: """Reconfig a cluster. @@ -1812,9 +1809,9 @@ def reconfig( def reconfig_async( self, - joining: Optional[str], - leaving: Optional[str], - new_members: Optional[str], + joining: str | None, + leaving: str | None, + new_members: str | None, from_config: int, ) -> IAsyncResult: """Asynchronously reconfig a cluster. Takes the same arguments as @@ -1872,7 +1869,7 @@ def create( self, path: str, value: bytes = b"", - acl: Optional[Sequence[ACL]] = None, + acl: Sequence[ACL] | None = None, ephemeral: bool = False, sequence: bool = False, ) -> None: @@ -1990,7 +1987,7 @@ def _check_tx_state(self) -> None: def _add( self, request: Any, - post_processor: Optional[Callable[[Any], Any]] = None, + post_processor: Callable[[Any], Any] | None = None, ) -> None: self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 687730273..cbba1b9e1 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -11,7 +11,7 @@ import socket import time from types import ModuleType -from typing import Any, Callable, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from kazoo.interfaces import IAsyncResult @@ -45,7 +45,7 @@ def __init__( timeout_factory: Callable[[], Any], ) -> None: self._handler = handler - self._exception: Union[object, None, Exception] = _NONE + self._exception: object | Exception | None = _NONE self._condition = condition_factory() self._callbacks: list[CallbackFunc] = [] self._timeout_factory = timeout_factory @@ -61,7 +61,7 @@ def successful(self) -> bool: return self._exception is None @property - def exception(self) -> Optional[Exception]: + def exception(self) -> Exception | None: if self._exception is not _NONE: # The next line should have return-value, but hound ci # is frankly nothing but a hound dog @@ -83,7 +83,7 @@ def set_exception(self, exception: Exception) -> None: self._do_callbacks() self._condition.notify_all() - def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception. If there is no value raises TimeoutError. @@ -112,7 +112,7 @@ def get_nowait(self) -> Any: """ return self.get(block=False) - def wait(self, timeout: Optional[float] = None) -> bool: + def wait(self, timeout: float | None = None) -> bool: """Block until the instance is ready.""" with self._condition: if not self.ready(): @@ -222,17 +222,17 @@ def create_tcp_socket(module: ModuleType) -> Socket: def create_tcp_connection( module: ModuleType, address: Any, - hostname: Optional[str] = None, - timeout: Optional[float] = None, + hostname: str | None = None, + timeout: float | None = None, use_ssl: bool = False, - ca: Optional[str] = None, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - keyfile_password: Optional[str] = None, + ca: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, verify_certs: bool = True, check_hostname: bool = False, - options: Optional[ssl.Options] = None, - ciphers: Optional[str] = None, + options: ssl.Options | None = None, + ciphers: str | None = None, ) -> Socket: end = None if timeout is None: @@ -241,7 +241,7 @@ def create_tcp_connection( timeout = module.getdefaulttimeout() if timeout is not None: end = time.monotonic() + timeout - sock: Optional[Socket] = None + sock: Socket | None = None while True: timeout_at = end if end is None else end - time.monotonic() @@ -398,7 +398,7 @@ def selector_select( rlist: list[Any], wlist: list[Any], xlist: list[Any], - timeout: Optional[float] = None, + timeout: float | None = None, selectors_module: ModuleType = selectors, ) -> tuple[list[int], list[int], list[int]]: """Selector-based drop-in replacement for select to overcome select diff --git a/kazoo/hosts.py b/kazoo/hosts.py index 34e2c3a15..cda746a3a 100644 --- a/kazoo/hosts.py +++ b/kazoo/hosts.py @@ -1,12 +1,11 @@ from __future__ import annotations import urllib.parse -from typing import Optional, Union def collect_hosts( - hosts: Union[str, list[str]], -) -> tuple[list[tuple[str, int]], Optional[str]]: + hosts: str | list[str], +) -> tuple[list[tuple[str, int]], str | None]: """ Collect a set of hosts and an optional chroot from a string or a list of strings. diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 9e688d708..eedbc36d0 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -13,7 +13,7 @@ import abc import queue -from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING +from typing import Any, Callable, Protocol, TYPE_CHECKING if TYPE_CHECKING: from kazoo.protocol.states import Callback @@ -107,7 +107,7 @@ def select( rlist: list, wlist: list, xlist: list, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> tuple[list, list, list]: """A select method that implements Python's select.select API""" @@ -222,7 +222,7 @@ def set_exception(self, exception: Exception) -> None: block at all.""" @abc.abstractmethod - def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception :param block: Whether this method should block or return @@ -245,7 +245,7 @@ def get_nowait(self) -> Any: the associated :class:`IHandler` interface.""" @abc.abstractmethod - def wait(self, timeout: Optional[float] = None) -> Any: + def wait(self, timeout: float | None = None) -> Any: """Block until the instance is ready. :param timeout: How long to wait for a value when `block` is @@ -281,6 +281,6 @@ def unlink(self, callback: Callable[[IAsyncResult], Any]) -> None: @property @abc.abstractmethod - def exception(self) -> Optional[Exception]: + def exception(self) -> Exception | None: """The exception set by :meth:`set_exception` or `None` if no exception has been set""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 0c3ae30f4..4f9b88ab8 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -11,7 +11,7 @@ import socket import ssl import time -from typing import Any, Iterator, Literal, Optional, Union, TYPE_CHECKING +from typing import Any, Iterator, Literal, TYPE_CHECKING from kazoo.exceptions import ( AuthFailedError, @@ -110,17 +110,17 @@ def __init__( ): self.hosts = hosts self.connection = connection_func - self.last_attempt: Optional[float] = None + self.last_attempt: float | None = None self.socket_handling = socket_handling - def __iter__(self) -> Iterator[Union[tuple, None, bool]]: + def __iter__(self) -> Iterator[tuple | bool | None]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay: float) -> Union[tuple, None, bool]: + def _next_server(self, delay: float) -> tuple | bool | None: jitter = random.randint(0, 100) / 100.0 while ( time.monotonic() @@ -168,8 +168,8 @@ def __init__( self, client: KazooClient, retry_sleeper: Any, - logger: Optional[logging.Logger] = None, - sasl_options: Optional[dict] = None, + logger: logging.Logger | None = None, + sasl_options: dict | None = None, ): self.client = client self.handler = client.handler @@ -183,15 +183,15 @@ def __init__( self.connection_stopped.set() self.ping_outstanding = client.handler.event_object() - self._read_sock: Optional[Socket] = None - self._write_sock: Optional[Socket] = None + self._read_sock: Socket | None = None + self._write_sock: Socket | None = None - self._socket: Optional[Socket] = None - self._xid: Optional[int] = None - self._rw_server: Optional[tuple] = None - self._ro_mode: Optional[Union[Literal[False], Iterator]] = False + self._socket: Socket | None = None + self._xid: int | None = None + self._rw_server: tuple | None = None + self._ro_mode: Literal[False] | Iterator | None = False - self._connection_routine: Optional[Any] = None + self._connection_routine: Any | None = None self.sasl_options = sasl_options self.sasl_cli = None @@ -218,7 +218,7 @@ def start(self) -> None: ) self._connection_routine = self.handler.spawn(self.zk_loop) - def stop(self, timeout: Optional[float] = None) -> bool: + def stop(self, timeout: float | None = None) -> bool: """Ensure the writer has stopped, wait to see if it does.""" self.connection_stopped.wait(timeout) if self._connection_routine: @@ -250,14 +250,14 @@ def _server_pinger(self) -> RWPinger: self._socket_error_handling, ) - def _read_header(self, timeout: Optional[float]) -> tuple: + def _read_header(self, timeout: float | None) -> tuple: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) header, offset = ReplyHeader.deserialize(b, 0) return header, b, offset - def _read(self, length: int, timeout: Optional[float]) -> bytes: + def _read(self, length: int, timeout: float | None) -> bytes: msgparts = [] remaining = length # We know that self._socket is not None here because we only call @@ -303,9 +303,9 @@ def _read(self, length: int, timeout: Optional[float]) -> bytes: def _invoke( self, - timeout: Optional[float], + timeout: float | None, request: Any, - xid: Optional[int] = None, + xid: int | None = None, ) -> Any: """A special writer used during connection establishment only""" @@ -353,8 +353,8 @@ def _invoke( def _submit( self, request: Any, - timeout: Optional[float], - xid: Optional[int] = None, + timeout: float | None, + xid: int | None = None, ) -> None: """Submit a request object with a timeout value and optional xid""" @@ -372,7 +372,7 @@ def _submit( ) self._write(int_struct.pack(len(b)) + b, timeout) - def _write(self, msg: bytes, timeout: Optional[float]) -> None: + def _write(self, msg: bytes, timeout: float | None) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) @@ -439,7 +439,7 @@ def _read_response( header: Any, buffer: bytes, offset: int, - ) -> Optional[object]: + ) -> object | None: client = self.client request, async_object, xid = client._pending.popleft() if header.zxid and header.zxid > 0: @@ -509,7 +509,7 @@ def _read_response( return CLOSE_RESPONSE return None - def _read_socket(self, read_timeout: float) -> Optional[object]: + def _read_socket(self, read_timeout: float) -> object | None: """Called when there's something to read on the socket""" client = self.client diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 4311d4e42..274292613 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -3,7 +3,7 @@ import struct from collections import namedtuple -from typing import Any, ClassVar, Optional, Sequence, Union, TYPE_CHECKING +from typing import Any, ClassVar, Sequence, TYPE_CHECKING from kazoo.exceptions import EXCEPTIONS from kazoo.protocol.states import ZnodeStat @@ -47,7 +47,7 @@ def read_acl(bytes: bytes, offset: int) -> tuple: return ACL(perms, Id(scheme, id)), offset -def write_string(bytes: Optional[str]) -> bytes: +def write_string(bytes: str | None) -> bytes: if not bytes: return int_struct.pack(-1) else: @@ -55,7 +55,7 @@ def write_string(bytes: Optional[str]) -> bytes: return int_struct.pack(len(utf8_str)) + utf8_str -def write_buffer(bytes: Optional[bytes]) -> bytes: +def write_buffer(bytes: bytes | None) -> bytes: if bytes is None: return int_struct.pack(-1) else: @@ -109,7 +109,7 @@ class Connect( passwd: bytes read_only: bool - type: Optional[int] = None # Note: Not a classvar + type: int | None = None # Note: Not a classvar def serialize(self) -> bytearray: b = bytearray() @@ -146,7 +146,7 @@ def deserialize(cls, bytes: bytes, offset: int) -> tuple[Any, int]: class Create(namedtuple("Create", "path data acl flags")): path: str - data: Optional[bytes] + data: bytes | None acl: Sequence[ACL] flags: int @@ -190,7 +190,7 @@ def deserialize(cls, bytes: bytes, offset: int) -> bool: class Exists(namedtuple("Exists", "path watcher")): path: str - watcher: Optional[WatchFunc] + watcher: WatchFunc | None type: ClassVar[int] = 3 @@ -201,14 +201,14 @@ def serialize(self) -> bytearray: return b @classmethod - def deserialize(cls, bytes: bytes, offset: int) -> Optional[ZnodeStat]: + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat | None: stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return stat if stat.czxid != -1 else None class GetData(namedtuple("GetData", "path watcher")): path: str - watcher: Optional[WatchFunc] + watcher: WatchFunc | None type: ClassVar[int] = 4 @@ -221,7 +221,7 @@ def serialize(self) -> bytearray: @classmethod def deserialize( cls, bytes: bytes, offset: int - ) -> tuple[Optional[bytes], ZnodeStat]: + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat @@ -229,7 +229,7 @@ def deserialize( class SetData(namedtuple("SetData", "path data version")): path: str - data: Optional[bytes] + data: bytes | None version: int type: ClassVar[int] = 5 @@ -257,7 +257,7 @@ def serialize(self) -> bytearray: @classmethod def deserialize( cls, bytes: bytes, offset: int - ) -> Union[tuple[list[ACL], ZnodeStat], list[ACL]]: + ) -> tuple[list[ACL], ZnodeStat] | list[ACL]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -298,7 +298,7 @@ def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: class GetChildren(namedtuple("GetChildren", "path watcher")): path: str - watcher: Optional[WatchFunc] + watcher: WatchFunc | None type: ClassVar[int] = 8 @@ -337,7 +337,7 @@ def deserialize(cls, buffer: bytes, offset: int) -> str: class GetChildren2(namedtuple("GetChildren2", "path watcher")): path: str - watcher: Optional[WatchFunc] + watcher: WatchFunc | None type: ClassVar[int] = 12 @@ -350,7 +350,7 @@ def serialize(self) -> bytearray: @classmethod def deserialize( cls, bytes: bytes, offset: int - ) -> Union[tuple[list[str], ZnodeStat], list[str]]: + ) -> tuple[list[str], ZnodeStat] | list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -427,7 +427,7 @@ def unchroot(client: KazooClient, response: list[Any]) -> list[Any]: class Create2(namedtuple("Create2", "path data acl flags")): path: str - data: Optional[bytes] + data: bytes | None acl: Sequence[ACL] flags: int @@ -457,9 +457,9 @@ def deserialize(cls, bytes: bytes, offset: int) -> tuple[str, ZnodeStat]: class Reconfig( namedtuple("Reconfig", "joining leaving new_members config_id") ): - joining: Optional[str] - leaving: Optional[str] - new_members: Optional[str] + joining: str | None + leaving: str | None + new_members: str | None config_id: int type: ClassVar[int] = 16 @@ -475,7 +475,7 @@ def serialize(self) -> bytearray: @classmethod def deserialize( cls, bytes: bytes, offset: int - ) -> tuple[Optional[bytes], ZnodeStat]: + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat @@ -497,7 +497,7 @@ def serialize(self) -> bytes: class SASL(namedtuple("SASL", "challenge")): - challenge: Optional[bytes] + challenge: bytes | None type: ClassVar[int] = 102 @@ -509,7 +509,7 @@ def serialize(self) -> bytearray: @classmethod def deserialize( cls, bytes: bytes, offset: int - ) -> tuple[Optional[bytes], int]: + ) -> tuple[bytes | None, int]: challenge, offset = read_buffer(bytes, offset) return challenge, offset @@ -546,9 +546,9 @@ def deserialize(cls, bytes: bytes, offset: int) -> tuple[ReplyHeader, int]: class MultiHeader(namedtuple("MultiHeader", "type, done, err")): - type: Optional[int] + type: int | None done: bool - err: Optional[int] + err: int | None def serialize(self) -> bytearray: b = bytearray() diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 5f2a23e60..2058d2576 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum -from typing import Callable, NamedTuple, Optional +from typing import Callable, NamedTuple # This is a (str, Enum) for backwards compatibility. @@ -149,7 +149,7 @@ class WatchedEvent(NamedTuple): type: EventType state: KeeperState - path: Optional[str] + path: str | None class Callback(NamedTuple): @@ -253,7 +253,7 @@ def last_modified(self) -> float: return self.mtime / 1000.0 @property - def owner_session_id(self) -> Optional[int]: + def owner_session_id(self) -> int | None: return self.ephemeralOwner or None @property diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index 9af59e1a5..efe6cd053 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -10,7 +10,7 @@ import os import socket import uuid -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import Any, Literal, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError from kazoo.protocol.states import EventType @@ -61,7 +61,7 @@ def remove(self) -> bool: except NoNodeError: return False - def wait(self, timeout: Optional[float] = None) -> bool: + def wait(self, timeout: float | None = None) -> bool: """Wait on the barrier to be cleared :returns: True if the barrier has been cleared, otherwise @@ -105,7 +105,7 @@ def __init__( client: KazooClient, path: str, num_clients: int, - identifier: Optional[str] = None, + identifier: str | None = None, ): """Create a Double Barrier diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 7bf33c0b5..170cd884f 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -18,7 +18,7 @@ import functools import logging import operator -from typing import Any, Callable, Generator, Optional, Protocol, TYPE_CHECKING +from typing import Any, Callable, Generator, Protocol, TYPE_CHECKING from kazoo.exceptions import NoNodeError, KazooException @@ -152,8 +152,8 @@ def listen_fault( return listener def get_data( - self, path: str, default: Optional[NodeData] = None - ) -> Optional[NodeData]: + self, path: str, default: NodeData | None = None + ) -> NodeData | None: """Gets data of a node from cache. :param path: The absolute path string. @@ -166,8 +166,8 @@ def get_data( return default if node is None else node._data def get_children( - self, path: str, default: Optional[frozenset[str]] = None - ) -> Optional[frozenset[str]]: + self, path: str, default: frozenset[str] | None = None + ) -> frozenset[str] | None: """Gets node children list from in-memory snapshot. :param path: The absolute path string. @@ -179,7 +179,7 @@ def get_children( node = self._find_node(path) return default if node is None else frozenset(node._children) - def _find_node(self, path: str) -> Optional[TreeNode]: + def _find_node(self, path: str) -> TreeNode | None: if not path.startswith(self._root._path): raise ValueError("outside of tree") striped_path = path[len(self._root._path) :].strip("/") @@ -232,7 +232,7 @@ def _session_watcher(self, state: Any) -> None: class AsyncWatcher(Protocol): - def __call__(self, path: str, watch: Optional[WatchFunc]) -> IAsyncResult: + def __call__(self, path: str, watch: WatchFunc | None) -> IAsyncResult: ... @@ -258,14 +258,14 @@ class TreeNode(object): STATE_LIVE = 1 STATE_DEAD = 2 - def __init__(self, tree: TreeCache, path: str, parent: Optional[TreeNode]): + def __init__(self, tree: TreeCache, path: str, parent: TreeNode | None): self._tree = tree self._path = path self._parent = parent self._depth: int = parent._depth + 1 if parent is not None else 0 self._children: dict[str, TreeNode] = {} self._state = self.STATE_PENDING - self._data: Optional[NodeData] = None + self._data: NodeData | None = None @classmethod def make_root(cls, tree: TreeCache, path: str) -> TreeNode: diff --git a/kazoo/recipe/counter.py b/kazoo/recipe/counter.py index 1be44902d..531322522 100644 --- a/kazoo/recipe/counter.py +++ b/kazoo/recipe/counter.py @@ -8,7 +8,7 @@ from __future__ import annotations import struct -from typing import Optional, Union, TYPE_CHECKING +from typing import Union, TYPE_CHECKING from kazoo.exceptions import BadVersionError from kazoo.retry import ForceRetryError @@ -91,8 +91,8 @@ def __init__( self.default_type = type(default) self.support_curator = support_curator self._ensured_path = False - self.pre_value: Optional[Number] = None - self.post_value: Optional[Number] = None + self.pre_value: Number | None = None + self.post_value: Number | None = None if self.support_curator and not isinstance(self.default, int): raise TypeError( "when support_curator is enabled the default " diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 82a32cd3c..09fe43349 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from kazoo.exceptions import CancelledError @@ -34,7 +34,7 @@ def __init__( self, client: KazooClient, path: str, - identifier: Optional[str] = None, + identifier: str | None = None, ): """Create a Kazoo Leader Election diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index b43d00c70..1687b61c6 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -11,7 +11,7 @@ import datetime import json import socket -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from kazoo.exceptions import CancelledError @@ -58,7 +58,7 @@ def __init__( client: KazooClient, path: str, duration: datetime.timedelta, - identifier: Optional[str] = None, + identifier: str | None = None, utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): """Create a non-blocking lease. @@ -158,7 +158,7 @@ def __init__( count: int, path: str, duration: datetime.timedelta, - identifier: Optional[str] = None, + identifier: str | None = None, utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): self.obtained = False diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 241aea86c..29ce5cd91 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -24,9 +24,7 @@ Any, Iterable, Literal, - Optional, Pattern, - Union, TYPE_CHECKING, ) @@ -48,14 +46,14 @@ class _Watch(object): - def __init__(self, duration: Optional[float] = None): + def __init__(self, duration: float | None = None): self.duration = duration - self.started_at: Optional[float] = None + self.started_at: float | None = None def start(self) -> None: self.started_at = time.monotonic() - def leftover(self) -> Optional[float]: + def leftover(self) -> float | None: if self.duration is None: return None else: @@ -103,7 +101,7 @@ def __init__( self, client: KazooClient, path: str, - identifier: Optional[str] = None, + identifier: str | None = None, extra_lock_patterns: Iterable[str] = (), ): """Create a Kazoo lock. @@ -137,7 +135,7 @@ def __init__( # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock self.data = str(identifier or "").encode("utf-8") - self.node: Optional[str] = None + self.node: str | None = None self.wake_event = client.handler.event_object() @@ -169,7 +167,7 @@ def cancel(self) -> None: def acquire( self, blocking: bool = True, - timeout: Optional[float] = None, + timeout: float | None = None, ephemeral: bool = True, ) -> bool: """ @@ -244,7 +242,7 @@ def _watch_session(self, state: Any) -> bool: def _inner_acquire( self, blocking: bool, - timeout: Optional[float], + timeout: float | None, ephemeral: bool = True, ) -> bool: # wait until it's our chance to get it.. @@ -257,7 +255,7 @@ def _inner_acquire( if not self.assured_path: self._ensure_path() - node: Optional[str] = None + node: str | None = None if self.create_tried: node = self._find_node() else: @@ -306,7 +304,7 @@ def _inner_acquire( def _watch_predecessor(self, event: Any) -> None: self.wake_event.set() - def _get_predecessor(self, node: str) -> Optional[str]: + def _get_predecessor(self, node: str) -> str | None: """returns `node`'s predecessor or None Note: This handle the case where the current lock is not a contender @@ -315,7 +313,7 @@ def _get_predecessor(self, node: str) -> Optional[str]: """ node_sequence = node[len(self.prefix) :] children = self.client.get_children(self.path) - found_self: Union[Literal[False], None, re.Match[str]] = False + found_self: Literal[False] | re.Match[str] | None = False # Filter out the contenders using the computed regex contender_matches = [] for child in children: @@ -346,7 +344,7 @@ def _get_predecessor(self, node: str) -> Optional[str]: sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) return sorted_matches[-1].string - def _find_node(self) -> Optional[str]: + def _find_node(self) -> str | None: children = self.client.get_children(self.path) for child in children: if child.startswith(self.prefix): @@ -541,7 +539,7 @@ def __init__( self, client: KazooClient, path: str, - identifier: Optional[str] = None, + identifier: str | None = None, max_leases: int = 1, ): """Create a Kazoo Lock @@ -608,7 +606,7 @@ def cancel(self) -> None: def acquire( self, blocking: bool = True, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> bool: """Acquire the semaphore. By defaults blocks and waits forever. @@ -650,7 +648,7 @@ def acquire( def _inner_acquire( self, blocking: bool, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> bool: """Inner loop that runs from the top anytime a command hits a retryable Zookeeper exception.""" @@ -729,7 +727,7 @@ def _get_lease(self, data: Any = None) -> bool: # Return current state return self.is_acquired - def _watch_session(self, state: Any) -> Optional[bool]: + def _watch_session(self, state: Any) -> bool | None: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 8d916fd4e..71878b606 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -25,7 +25,7 @@ import os import socket from enum import Enum -from typing import Any, Callable, Iterator, Optional, Sequence, TYPE_CHECKING +from typing import Any, Callable, Iterator, Sequence, TYPE_CHECKING from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState @@ -152,13 +152,12 @@ def __init__( client: KazooClient, path: str, set: Sequence[str], - partition_func: Optional[ - Callable[[str, list[str], Sequence[str]], list[str]] - ] = None, - identifier: Optional[str] = None, + partition_func: Callable[[str, list[str], Sequence[str]], list[str]] + | None = None, + identifier: str | None = None, time_boundary: float = 30, max_reaction_time: float = 1, - state_change_event: Optional[Any] = None, + state_change_event: Any | None = None, ): """Create a :class:`~SetPartitioner` instance @@ -405,7 +404,7 @@ def _abort_lock_acquisition(self) -> None: def _child_watching( self, - func: Optional[Callable[..., Any]] = None, + func: Callable[..., Any] | None = None, client_handler: bool = False, ) -> Any: """Called when children are being watched to stabilize diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index baf517f27..f28f2be11 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -11,7 +11,7 @@ from __future__ import annotations import uuid -from typing import Any, Iterator, Optional, TYPE_CHECKING +from typing import Any, Iterator, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError @@ -26,7 +26,7 @@ def __init__( self, client: KazooClient, path: str, - identifier: Optional[str] = None, + identifier: str | None = None, ): """ :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -56,7 +56,8 @@ def _inner_join(self) -> None: try: # This and the #type: ignore[attr-defined] below could be removed # by setting up create_path in the constructor but trying to avoid - # changing the code too much + # changing the code too much. It does actually cause later versions + # of pylint to error though. self.client.create( self.create_path, # type: ignore[attr-defined] self.data, @@ -95,7 +96,7 @@ class Party(BaseParty): _NODE_NAME = "__party__" def __init__( - self, client: KazooClient, path: str, identifier: Optional[str] = None + self, client: KazooClient, path: str, identifier: str | None = None ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = uuid.uuid4().hex + self._NODE_NAME @@ -134,7 +135,7 @@ def __init__( self, client: KazooClient, path: str, - identifier: Optional[str] = None, + identifier: str | None = None, ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = "-".join([uuid.uuid4().hex, self.data.decode("utf-8")]) diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index c79d2dd71..dae3ed5fd 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -13,7 +13,7 @@ from __future__ import annotations import uuid -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError from kazoo.protocol.states import EventType @@ -81,7 +81,7 @@ def __len__(self) -> int: """Return queue size.""" return super(Queue, self).__len__() - def get(self) -> Optional[bytes]: + def get(self) -> bytes | None: """ Get item data and remove an item from the queue. @@ -91,7 +91,7 @@ def get(self) -> Optional[bytes]: self._ensure_paths() return self.client.retry(self._inner_get) - def _inner_get(self) -> Optional[bytes]: + def _inner_get(self) -> bytes | None: if not self._children: self._children = self.client.retry( self.client.get_children, self.path @@ -164,7 +164,7 @@ def __init__(self, client: KazooClient, path: str): """ super(LockingQueue, self).__init__(client, path) self.id = uuid.uuid4().hex.encode() - self.processing_element: Optional[tuple[str, bytes]] = None + self.processing_element: tuple[str, bytes] | None = None self._lock_path = self.path + self.lock self._entries_path = self.path + self.entries self.structure_paths = (self._lock_path, self._entries_path) @@ -228,7 +228,7 @@ def put_all(self, values: list[bytes], priority: int = 100) -> None: sequence=True, ) - def get(self, timeout: Optional[float] = None) -> Optional[bytes]: + def get(self, timeout: float | None = None) -> bytes | None: """Locks and gets an entry from the queue. If a previously got entry was not consumed, this method will return that entry. @@ -296,13 +296,13 @@ def release(self) -> bool: else: return False - def _inner_get(self, timeout: Optional[float]) -> Optional[bytes]: + def _inner_get(self, timeout: float | None) -> bytes | None: flag = self.client.handler.event_object() lock = self.client.handler.lock_object() canceled = False value = [] - def check_for_updates(event: Optional[Any]) -> None: + def check_for_updates(event: Any | None) -> None: if event is not None and event.type != EventType.CHILD: return with lock: @@ -346,7 +346,7 @@ def _filter_locked(self, values: list[str], taken: list[str]) -> list[str]: else [x for x in available if x not in taken] ) - def _take(self, id_: str) -> Optional[tuple[str, bytes]]: + def _take(self, id_: str) -> tuple[str, bytes] | None: try: self.client.create( "{path}/{id}".format(path=self._lock_path, id=id_), diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index d45e01acb..32f38042f 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -108,7 +108,7 @@ def __init__( self, client: KazooClient, path: str, - func: Optional[DataWatchFunc] = None, + func: DataWatchFunc | None = None, *args: Any, **kwargs: Any, ): @@ -130,7 +130,7 @@ def __init__( self._func = func self._stopped = False self._run_lock = client.handler.lock_object() - self._version: Optional[int] = None + self._version: int | None = None self._retry = KazooRetry( max_tries=None, sleep_func=client.handler.sleep_func ) @@ -181,8 +181,8 @@ def __call__(self, func: DataWatchFunc) -> DataWatchFunc: def _log_func_exception( self, data: Any, - stat: Optional[ZnodeStat], - event: Optional[WatchedEvent] = None, + stat: ZnodeStat | None, + event: WatchedEvent | None = None, ) -> None: try: # For backwards compatibility, don't send event to the @@ -208,7 +208,7 @@ def _log_func_exception( raise @_ignore_closed - def _get_data(self, event: Optional[WatchedEvent] = None) -> None: + def _get_data(self, event: WatchedEvent | None = None) -> None: # Ensure this runs one at a time, possible because the session # watcher may trigger a run with self._run_lock: @@ -217,7 +217,7 @@ def _get_data(self, event: Optional[WatchedEvent] = None) -> None: initial_version = self._version - stat: Optional[ZnodeStat] + stat: ZnodeStat | None try: data, stat = self._retry( self._client.get, self._path, self._watcher @@ -296,7 +296,7 @@ def __init__( self, client: KazooClient, path: str, - func: Optional[ChildrenWatchFunc] = None, + func: ChildrenWatchFunc | None = None, allow_session_lost: bool = True, send_event: bool = False, ): @@ -331,7 +331,7 @@ def __init__( self._watch_established = False self._allow_session_lost = allow_session_lost self._run_lock = client.handler.lock_object() - self._prior_children: Optional[List[str]] = None + self._prior_children: list[str] | None = None self._used = False # Register our session listener if we're going to resume @@ -366,7 +366,7 @@ def __call__(self, func: ChildrenWatchFunc) -> ChildrenWatchFunc: return func @_ignore_closed - def _get_children(self, event: Optional[WatchedEvent] = None) -> None: + def _get_children(self, event: WatchedEvent | None = None) -> None: with self._run_lock: # Ensure this runs one at a time if self._stopped: return diff --git a/kazoo/retry.py b/kazoo/retry.py index c8e0bcc2b..d406696ff 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -3,7 +3,7 @@ import logging import random import time -from typing import Any, Callable, Optional, TypeVar +from typing import Any, Callable, TypeVar from kazoo.exceptions import ( ConnectionClosedError, @@ -43,15 +43,15 @@ class KazooRetry(object): def __init__( self, - max_tries: Optional[int] = 1, + max_tries: int | None = 1, delay: float = 0.1, backoff: int = 2, max_jitter: float = 0.4, max_delay: float = 60.0, ignore_expire: bool = True, sleep_func: Callable[[float], None] = time.sleep, - deadline: Optional[float] = None, - interrupt: Optional[Callable[[], bool]] = None, + deadline: float | None = None, + interrupt: Callable[[], bool] | None = None, ) -> None: """Create a :class:`KazooRetry` instance for retrying function calls. @@ -85,7 +85,7 @@ def __init__( self._attempts = 0 self._cur_delay = delay self.deadline = deadline - self._cur_stoptime: Optional[float] = None + self._cur_stoptime: float | None = None self.sleep_func = sleep_func self.retry_exceptions: tuple[ type[Exception], ... From 84c0766965b6de19e6e33277924d156427f97232 Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Wed, 29 Apr 2026 23:00:05 +0100 Subject: [PATCH 3/8] Found some more dead python2 --- kazoo/recipe/lease.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index 1687b61c6..0d8c40463 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -123,11 +123,6 @@ def _encode(self, data_dict: dict[str, Any]) -> bytes: def _decode(self, raw: bytes) -> dict[str, Any]: return json.loads(raw.decode(self._byte_encoding)) - # Python 2.x - def __nonzero__(self) -> bool: - return self.obtained - - # Python 3.x def __bool__(self) -> bool: return self.obtained @@ -174,10 +169,5 @@ def __init__( self.obtained = True break - # Python 2.x - def __nonzero__(self) -> bool: - return self.obtained - - # Python 3.x def __bool__(self) -> bool: return self.obtained From 81f937d0eb1326216d65804a3c11a573d8a7f17a Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Thu, 30 Apr 2026 21:13:05 +0100 Subject: [PATCH 4/8] tighten the screw --- .gitignore | 2 +- kazoo/client.py | 97 ++++++++++++++++++++------------- kazoo/exceptions.py | 16 ++++-- kazoo/handlers/eventlet.py | 75 ++++++++++++++++--------- kazoo/handlers/gevent.py | 53 +++++++++++++----- kazoo/handlers/threading.py | 46 +++++++++++----- kazoo/handlers/utils.py | 2 +- kazoo/interfaces.py | 96 ++++++++++++++++++++++++++++---- kazoo/protocol/connection.py | 73 ++++++++++++++++++------- kazoo/protocol/serialization.py | 24 +++++--- kazoo/protocol/states.py | 6 +- kazoo/recipe/cache.py | 6 +- kazoo/recipe/lease.py | 6 +- kazoo/recipe/partitioner.py | 2 +- pyproject.toml | 19 ++----- 15 files changed, 368 insertions(+), 155 deletions(-) diff --git a/.gitignore b/.gitignore index 94c5a4e2e..8df69daac 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,6 @@ __pycache__/ !.gitignore !.git-blame-ignore-revs -.vscode/settings.json +.vscode/ .*_cache/ coverage.xml diff --git a/kazoo/client.py b/kazoo/client.py index 1e195b249..1bd5ae80b 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -10,13 +10,15 @@ import re import warnings from typing import ( + cast, + overload, Any, Callable, + Deque, Literal, Optional, Sequence, Set, - overload, TYPE_CHECKING, ) @@ -127,15 +129,16 @@ def __init__( self, hosts: str | list[str] = "127.0.0.1:2181", timeout: float = 10.0, - client_id: tuple | None = None, + client_id: tuple[int | None, bytes] | None = None, handler: IHandler | None = None, default_acl: Sequence[ACL] | None = None, - auth_data: set | None = None, - sasl_options: dict | None = None, + auth_data: set[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, read_only: bool | None = None, randomize_hosts: bool = True, - connection_retry: KazooRetry | dict | None = None, - command_retry: KazooRetry | dict | None = None, + # FIXME the dict should be a TypeDict + connection_retry: KazooRetry | dict[str, Any] | None = None, + command_retry: KazooRetry | dict[str, Any] | None = None, logger: logging.Logger | None = None, keyfile: str | None = None, keyfile_password: str | None = None, @@ -459,8 +462,8 @@ def _retry(*args: Any, **kwargs: Any) -> Any: def _reset(self) -> None: """Resets a variety of client states for a new connection.""" - self._queue: deque = deque() - self._pending: deque = deque() + self._queue: Deque[tuple[Any, IAsyncResult]] = deque() + self._pending: Deque[tuple[Any, IAsyncResult, int]] = deque() self._reset_watchers() self._reset_session() @@ -497,7 +500,7 @@ def client_state(self) -> KeeperState: return self._state @property - def client_id(self) -> tuple | None: + def client_id(self) -> tuple[Any, Any] | None: """Returns the client id for this Zookeeper session if connected. @@ -790,7 +793,7 @@ def stop(self) -> None: return self._stopped.set() - self._queue.append((CloseInstance, None)) + self._queue.append((CloseInstance, cast("IAsyncResult", None))) try: # This assert should never fail since the connection should # have been started but I'm not sure how to persaude mypy of that @@ -863,7 +866,7 @@ def command(self, cmd: bytes = b"ruok") -> str: sock.close() return result.decode("utf-8", "replace") - def server_version(self, retries: int = 3) -> tuple: + def server_version(self, retries: int = 3) -> tuple[int, ...]: """Get the version of the currently connected ZK server. :returns: The server version, for example (3, 4, 3). @@ -887,8 +890,8 @@ def _try_fetch() -> tuple[int, ...] | None: if k: data_parsed[k] = v version = data_parsed.get(ENVI_VERSION_KEY, "") - # a) if you get an unexpected answer, you'll crash - # b) not changing the code, so just ignoring the type error + # FIXME If you get an unexpected answer, you'll crash - not + # changing the code, so just ignoring the type error version_digits = ENVI_VERSION.match( version ).group( # type: ignore[union-attr] @@ -907,8 +910,9 @@ def _is_valid(version: tuple[int, ...] | None) -> bool: return True return False - # A better way of doing this would be to put the initial _try_fetch in - # the loop and inline _is_valid but I want to minimise code changes + # FIXME A better way of doing this would be to put the initial + # _try_fetch in the loop and inline _is_valid but I want to minimise + # code changes # Try 1 + retries amount of times to get a version that we know # will likely be acceptable... @@ -943,7 +947,7 @@ def add_auth(self, scheme: str, credential: str) -> bool: the session state will be set to AUTH_FAILED as well. """ - return self.add_auth_async(scheme, credential).get() + return cast("bool", self.add_auth_async(scheme, credential).get()) def add_auth_async(self, scheme: str, credential: str) -> IAsyncResult: """Asynchronously send credentials to server. Takes the same @@ -1009,7 +1013,7 @@ def sync(self, path: str) -> str: .. versionadded:: 0.5 """ - return self.sync_async(path).get() + return cast("str", self.sync_async(path).get()) @overload def create( @@ -1125,15 +1129,18 @@ def create( The `include_data` option. """ acl = acl or self.default_acl - return self.create_async( - path, - value, - acl=acl, - ephemeral=ephemeral, - sequence=sequence, - makepath=makepath, - include_data=include_data, - ).get() + return cast( + "str | tuple[str, ZnodeStat]", + self.create_async( + path, + value, + acl=acl, + ephemeral=ephemeral, + sequence=sequence, + makepath=makepath, + include_data=include_data, + ).get(), + ) def create_async( self, @@ -1272,7 +1279,7 @@ def ensure_path(self, path: str, acl: Sequence[ACL] | None = None) -> bool: :param acl: Permissions for node. """ - return self.ensure_path_async(path, acl).get() + return cast("bool", self.ensure_path_async(path, acl).get()) def ensure_path_async( self, path: str, acl: Sequence[ACL] | None = None @@ -1289,9 +1296,9 @@ def ensure_path_async( async_result = self.handler.async_result() @wrap(async_result) - def create_completion(result: Any) -> bool: + def create_completion(result: IAsyncResult) -> bool: try: - return result.get() + return cast("bool", result.get()) except NodeExistsError: return True @@ -1341,7 +1348,9 @@ def exists( returns a non-zero error code. """ - return self.exists_async(path, watch=watch).get() + return cast( + "ZnodeStat | None", self.exists_async(path, watch=watch).get() + ) def exists_async( self, path: str, watch: WatchFunc | None = None @@ -1388,7 +1397,9 @@ def get( returns a non-zero error code """ - return self.get_async(path, watch=watch).get() + return cast( + "tuple[bytes, ZnodeStat]", self.get_async(path, watch=watch).get() + ) def get_async( self, path: str, watch: WatchFunc | None = None @@ -1449,9 +1460,12 @@ def get_children( The `include_data` option. """ - return self.get_children_async( - path, watch=watch, include_data=include_data - ).get() + return cast( + "list[str]", + self.get_children_async( + path, watch=watch, include_data=include_data + ).get(), + ) def get_children_async( self, @@ -1473,6 +1487,7 @@ def get_children_async( raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() + # FIXME? Do this as req = getc2 if include_data else getc req: GetChildren | GetChildren2 if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) @@ -1499,7 +1514,9 @@ def get_acls(self, path: str) -> tuple[list[ACL], ZnodeStat]: .. versionadded:: 0.5 """ - return self.get_acls_async(path).get() + return cast( + "tuple[list[ACL], ZnodeStat]", self.get_acls_async(path).get() + ) def get_acls_async(self, path: str) -> IAsyncResult: """Return the ACL and stat of the node of the given path. Takes @@ -1544,7 +1561,9 @@ def set_acls( .. versionadded:: 0.5 """ - return self.set_acls_async(path, acls, version).get() + return cast( + "ZnodeStat", self.set_acls_async(path, acls, version).get() + ) def set_acls_async( self, path: str, acls: Sequence[ACL], version: int = -1 @@ -1606,7 +1625,7 @@ def set( returns a non-zero error code. """ - return self.set_async(path, value, version).get() + return cast("ZnodeStat", self.set_async(path, value, version).get()) def set_async( self, path: str, value: bytes | None, version: int = -1 @@ -1805,7 +1824,7 @@ def reconfig( result = self.reconfig_async( joining, leaving, new_members, from_config ) - return result.get() + return cast("tuple[bytes, ZnodeStat]", result.get()) def reconfig_async( self, @@ -1970,7 +1989,7 @@ def commit(self) -> list[Any]: transaction. """ - return self.commit_async().get() + return cast("list[Any]", self.commit_async().get()) def __enter__(self) -> TransactionRequest: return self diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index 2bfa7d66e..cb943b47f 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -58,13 +58,21 @@ def _invalid_error_code() -> Any: raise RuntimeError("Invalid error code") -EXCEPTIONS: defaultdict = defaultdict(_invalid_error_code) +EXCEPTIONS: defaultdict[int, Type[ZookeeperError]] = defaultdict( + _invalid_error_code +) -def _zookeeper_exception(code: int) -> Callable[[Type[Any]], Type[Any]]: - def decorator(klass: Type[Any]) -> Type[Any]: +def _zookeeper_exception( + code: int, +) -> Callable[[Type[ZookeeperError]], Type[ZookeeperError]]: + def decorator(klass: Type[ZookeeperError]) -> Type[ZookeeperError]: EXCEPTIONS[code] = klass - klass.code = code + # Unfortunately there is currently no good of doing the assignment here + # in a way that type checkers would allow. It's a known problem (see + # https://discuss.python.org/t/how-to-type-hint-a-class-decorator/63010 + # ) + klass.code = code # type: ignore[attr-defined] return klass return decorator diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 028a18bfd..3cefeb546 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -7,7 +7,7 @@ import contextlib import logging -from typing import Any, Generator, TYPE_CHECKING +from typing import cast, Any, Generator, TYPE_CHECKING import eventlet from eventlet.green import socket as green_socket @@ -20,7 +20,8 @@ from kazoo.handlers.utils import selector_select if TYPE_CHECKING: - from kazoo.interfaces import Socket + from kazoo.interfaces import Event, Lockable, ReentrantLock, Socket + from kazoo.protocol.states import Callback LOG = logging.getLogger(__name__) @@ -36,11 +37,11 @@ def _yield_before_after() -> Generator[None, None, None]: # See: http://eventlet.net/doc/modules/greenthread.html # for how this zero sleep is really a cooperative yield to other potential # co-routines... - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] try: yield finally: - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] class TimeoutError(Exception): @@ -52,7 +53,9 @@ class AsyncResult(utils.AsyncResult): def __init__(self, handler: Any): super(AsyncResult, self).__init__( - handler, green_threading.Condition, TimeoutError + handler, + green_threading.Condition, # type: ignore[attr-defined] + TimeoutError, ) @@ -91,16 +94,23 @@ class SequentialEventletHandler(object): def __init__(self) -> None: """Create a :class:`SequentialEventletHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() - self._workers: list[ - tuple[eventlet.GreenThread, green_queue.LightQueue] + self.callback_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self.completion_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self._workers: list[ # type: ignore[name-defined] + tuple[ + eventlet.GreenThread, + green_queue.LightQueue, + ] ] = [] self._started = False @staticmethod def sleep_func(wait: float) -> None: - green_time.sleep(wait) + green_time.sleep(wait) # type: ignore[attr-defined, no-untyped-call] @property def running(self) -> bool: @@ -110,7 +120,7 @@ def running(self) -> bool: def _process_completion_queue(self) -> None: while True: - cb = self.completion_queue.get() + cb = self.completion_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -126,7 +136,7 @@ def _process_completion_queue(self) -> None: def _process_callback_queue(self) -> None: while True: - cb = self.callback_queue.get() + cb = self.callback_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -145,9 +155,13 @@ def start(self) -> None: # Spawn our worker threads, we have # - A callback worker for watch events to be called # - A completion worker for completion events to be called - w = eventlet.spawn(self._process_completion_queue) + w = eventlet.spawn( + self._process_completion_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.completion_queue)) - w = eventlet.spawn(self._process_callback_queue) + w = eventlet.spawn( + self._process_callback_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.callback_queue)) self._started = True atexit.register(self.stop) @@ -155,7 +169,7 @@ def start(self) -> None: def stop(self) -> None: while self._workers: w, q = self._workers.pop() - q.put(_STOP) + q.put(_STOP) # type: ignore[no-untyped-call] w.wait() self._started = False atexit.unregister(self.stop) @@ -166,14 +180,21 @@ def socket(self, *args: Any, **kwargs: Any) -> Socket: def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(green_socket) - def event_object(self) -> green_threading.Event: - return green_threading.Event() + def event_object(self) -> Event: + return cast( + "Event", green_threading.Event() # type: ignore[attr-defined] + ) - def lock_object(self) -> green_threading.Lock: - return green_threading.Lock() + def lock_object(self) -> Lockable: + return cast( + "Lockable", green_threading.Lock() # type: ignore[attr-defined] + ) - def rlock_object(self) -> green_threading.RLock: - return green_threading.RLock() + def rlock_object(self) -> ReentrantLock: + return cast( + "ReentrantLock", + green_threading.RLock(), # type: ignore[attr-defined] + ) def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) @@ -195,11 +216,15 @@ def async_result(self) -> AsyncResult: def spawn( self, func: Any, *args: Any, **kwargs: Any - ) -> green_threading.Thread: - t = green_threading.Thread(target=func, args=args, kwargs=kwargs) + ) -> green_threading.Thread: # type: ignore[name-defined] + t = green_threading.Thread( # type: ignore[attr-defined] + target=func, args=args, kwargs=kwargs + ) t.daemon = True t.start() return t - def dispatch_callback(self, callback: Any) -> None: - self.callback_queue.put(lambda: callback.func(*callback.args)) + def dispatch_callback(self, callback: Callback) -> None: + self.callback_queue.put( # type: ignore[no-untyped-call] + lambda: callback.func(*callback.args) + ) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index ebc04b2e5..1576eef7f 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -6,7 +6,7 @@ import atexit import logging -from typing import Any, TYPE_CHECKING +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast import gevent from gevent import socket @@ -18,22 +18,30 @@ from kazoo.handlers.utils import selector_select +from gevent import Greenlet from gevent.lock import Semaphore, RLock from kazoo.handlers import utils if TYPE_CHECKING: - from gevent import Greenlet - from kazoo.interfaces import Socket + from kazoo.interfaces import HasFileNo, Lockable, Socket + from kazoo.protocol.states import Callback _using_libevent = gevent.__version__.startswith("0.") log = logging.getLogger(__name__) + _STOP = object() AsyncResult = gevent.event.AsyncResult +# The following would be great typenames, but python3.8 complains about type +# objects not being indexable. +# GCallback = Callable[..., None] +# Worker = Greenlet[..., Any] +# CallbackQueue = gevent.queue.Queue[Callable[..., None]] + class SequentialGeventHandler(object): """Gevent handler for sequentially executing callbacks. @@ -63,11 +71,13 @@ class SequentialGeventHandler(object): def __init__(self) -> None: """Create a :class:`SequentialGeventHandler` instance""" - self.callback_queue: gevent.queue.Queue = self.queue_impl() + self.callback_queue: gevent.queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._async = None self._state_change = Semaphore() - self._workers: list[Greenlet] = [] + self._workers: list[Greenlet[..., Any]] = [] @property def running(self) -> bool: @@ -77,7 +87,9 @@ class timeout_exception(gevent.Timeout): def __init__(self, msg: Any): gevent.Timeout.__init__(self, exception=msg) - def _create_greenlet_worker(self, queue: Any) -> gevent.Greenlet: + def _create_greenlet_worker( + self, queue: gevent.queue.Queue[Callable[..., None]] + ) -> Greenlet[..., Any]: def greenlet_worker() -> None: while True: try: @@ -106,6 +118,7 @@ def start(self) -> None: # Spawn our worker greenlets, we have # - A callback worker for watch events to be called + # FIXME Why the loop? for queue in (self.callback_queue,): w = self._create_greenlet_worker(queue) self._workers.append(w) @@ -120,7 +133,7 @@ def stop(self) -> None: self._running = False for queue in (self.callback_queue,): - queue.put(_STOP) + queue.put(cast("Callable[..., None]", _STOP)) while self._workers: worker = self._workers.pop() @@ -131,7 +144,14 @@ def stop(self) -> None: atexit.unregister(self.stop) - def select(self, *args: Any, **kwargs: Any) -> tuple: + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[ + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + ]: + # FIXME use the correct arguments, not *args, *kwargs return selector_select( # Likely a bug in mypy (see # https://github.com/python/mypy/issues/6799) @@ -141,9 +161,11 @@ def select(self, *args: Any, **kwargs: Any) -> tuple: ) def socket(self, *args: Any, **kwargs: Any) -> Socket: + # See above return utils.create_tcp_socket(socket) def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + # See above return utils.create_tcp_connection(socket, *args, **kwargs) def create_socket_pair(self) -> tuple[Socket, Socket]: @@ -153,15 +175,18 @@ def event_object(self) -> gevent.event.Event: """Create an appropriate Event object""" return gevent.event.Event() - def lock_object(self) -> Any: + def lock_object(self) -> Lockable: """Create an appropriate Lock object""" - return gevent.thread.allocate_lock() + return cast( + "Lockable", + gevent.thread.allocate_lock(), # type: ignore[no-untyped-call] + ) def rlock_object(self) -> RLock: """Create an appropriate RLock object""" return RLock() - def async_result(self) -> AsyncResult: + def async_result(self) -> AsyncResult[Any]: """Create a :class:`AsyncResult` instance The :class:`AsyncResult` instance will have its completion @@ -172,11 +197,13 @@ def async_result(self) -> AsyncResult: """ return AsyncResult() - def spawn(self, func: Any, *args: Any, **kwargs: Any) -> gevent.Greenlet: + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> gevent.Greenlet[..., Any]: """Spawn a function to run asynchronously""" return gevent.spawn(func, *args, **kwargs) - def dispatch_callback(self, callback: Any) -> None: + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index f112022f3..7699487c3 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -21,14 +21,22 @@ import threading import time -from typing import Any, TYPE_CHECKING +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast from kazoo.handlers import utils from kazoo.handlers.utils import selector_select from kazoo.interfaces import IHandler if TYPE_CHECKING: - from kazoo.interfaces import Socket, SpawnedFunc + from kazoo.interfaces import ( + Event, + HasFileNo, + Lockable, + ReentrantLock, + Socket, + SpawnedFunc, + ) + from kazoo.protocol.states import Callback # sentinel objects _STOP = object() @@ -105,8 +113,12 @@ class SequentialThreadingHandler(IHandler): def __init__(self) -> None: """Create a :class:`SequentialThreadingHandler` instance""" - self.callback_queue: queue.Queue = self.queue_impl() - self.completion_queue: queue.Queue = self.queue_impl() + self.callback_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() + self.completion_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._state_change = threading.Lock() self._workers: list[threading.Thread] = [] @@ -116,7 +128,7 @@ def running(self) -> bool: return self._running def _create_thread_worker( - self, work_queue: queue.Queue + self, work_queue: queue.Queue[Callable[..., None]] ) -> threading.Thread: def _thread_worker() -> None: # pragma: nocover while True: @@ -161,7 +173,7 @@ def stop(self) -> None: self._running = False for work_queue in (self.completion_queue, self.callback_queue): - work_queue.put(_STOP) + work_queue.put(cast("Callable[..., None]", _STOP)) self._workers.reverse() while self._workers: @@ -173,7 +185,13 @@ def stop(self) -> None: self.completion_queue = self.queue_impl() atexit.unregister(self.stop) - def select(self, *args: Any, **kwargs: Any) -> tuple: + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[ + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + ]: return selector_select(*args, **kwargs) def socket(self) -> Socket: @@ -185,17 +203,19 @@ def create_connection(self, *args: Any, **kwargs: Any) -> Socket: def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self) -> threading.Event: + def event_object(self) -> Event: """Create an appropriate Event object""" return threading.Event() - def lock_object(self) -> threading.Lock: + def lock_object(self) -> Lockable: """Create a lock object""" - return threading.Lock() + # Note: This is not ideal, but the ContextManager Protocol seems to + # think you should return an object of the same type. + return cast("Lockable", threading.Lock()) - def rlock_object(self) -> threading.RLock: + def rlock_object(self) -> ReentrantLock: """Create an appropriate RLock object""" - return threading.RLock() + return cast("ReentrantLock", threading.RLock()) def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance""" @@ -209,7 +229,7 @@ def spawn( t.start() return t - def dispatch_callback(self, callback: Any) -> None: + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index cbba1b9e1..46f017529 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -214,7 +214,7 @@ def create_tcp_socket(module: ModuleType) -> Socket: if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover # if available, set cloexec flag during socket creation type_ |= module.SOCK_CLOEXEC - sock = module.socket(module.AF_INET, type_) + sock: Socket = module.socket(module.AF_INET, type_) _set_default_tcpsock_options(module, sock) return sock diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index eedbc36d0..6751bf49e 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -13,7 +13,13 @@ import abc import queue -from typing import Any, Callable, Protocol, TYPE_CHECKING +from typing import ( + Any, + Callable, + Iterable, + Protocol, + TYPE_CHECKING, +) if TYPE_CHECKING: from kazoo.protocol.states import Callback @@ -21,7 +27,14 @@ # public API -class Socket(Protocol): +class HasFileNo(Protocol): + """Protocol for things like select""" + + def fileno(self) -> int: + ... + + +class Socket(HasFileNo, Protocol): """This is for things that provide a socket.socket-like interface. This is required because: @@ -52,6 +65,65 @@ def setsockopt(self, level: int, optname: int, value: int) -> None: ... +class Lockable(Protocol): + """This is what threading.Lock implements. + + In python 3.9+ it's available natively. Though given it has some + very odd typing, I wouldn't put money on it. + """ + + def __enter__(self) -> None: + ... + + def __exit__(self, x: Any, y: Any, z: Any) -> None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> int | None: + """The gevent release returns an int...""" + ... + + def locked(self) -> bool: + ... + + +class ReentrantLock(Protocol): + """This is what threading.RLock implements. + + In python 3.14+, it's the same as Lock, which adds to the fun. + """ + + def __enter__(self) -> None: + ... + + def __exit__(self, x: Any, y: Any, z: Any) -> None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> None: + ... + + +class Event(Protocol): + """Protocol for threading.Event""" + + def is_set(self) -> bool: + ... + + def set(self) -> None: + ... + + def clear(self) -> None: + ... + + def wait(self, timeout: float | None = None) -> bool: + ... + + SpawnedFunc = Callable[..., None] @@ -90,7 +162,7 @@ class IHandler(abc.ABC): timeout_exception: type[Exception] = None # type: ignore[assignment] sleep_func: staticmethod[[float], None] = None # type: ignore[assignment] - queue_impl: type[queue.Queue] = None # type: ignore[assignment] + queue_impl: type[queue.Queue[Any]] = None # type: ignore[assignment] @abc.abstractmethod def start(self) -> None: @@ -104,11 +176,15 @@ def stop(self) -> None: @abc.abstractmethod def select( self, - rlist: list, - wlist: list, - xlist: list, + rlist: Iterable[int | HasFileNo], + wlist: Iterable[int | HasFileNo], + xlist: Iterable[int | HasFileNo], timeout: float | None = None, - ) -> tuple[list, list, list]: + ) -> tuple[ + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + Iterable[int | HasFileNo], + ]: """A select method that implements Python's select.select API""" @@ -126,17 +202,17 @@ def create_socket_pair(self) -> tuple[Socket, Socket]: """A socket method that implements Python's socket.socketpair API""" @abc.abstractmethod - def event_object(self) -> Any: + def event_object(self) -> Event: """Return an appropriate object that implements Python's threading.Event API""" @abc.abstractmethod - def lock_object(self) -> Any: + def lock_object(self) -> Lockable: """Return an appropriate object that implements Python's threading.Lock API""" @abc.abstractmethod - def rlock_object(self) -> Any: + def rlock_object(self) -> ReentrantLock: """Return an appropriate object that implements Python's threading.RLock API""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 4f9b88ab8..c2e0d5089 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -11,7 +11,7 @@ import socket import ssl import time -from typing import Any, Iterator, Literal, TYPE_CHECKING +from typing import Any, Iterator, Literal, TYPE_CHECKING, cast from kazoo.exceptions import ( AuthFailedError, @@ -104,7 +104,7 @@ class RWPinger(object): def __init__( self, - hosts: Any, + hosts: list[tuple[str, int]], connection_func: Any, socket_handling: Any, ): @@ -113,14 +113,16 @@ def __init__( self.last_attempt: float | None = None self.socket_handling = socket_handling - def __iter__(self) -> Iterator[tuple | bool | None]: + def __iter__(self) -> Iterator[tuple[str, int] | Literal[False] | None]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay: float) -> tuple | bool | None: + def _next_server( + self, delay: float + ) -> tuple[str, int] | Literal[False] | None: jitter = random.randint(0, 100) / 100.0 while ( time.monotonic() @@ -146,14 +148,14 @@ def _next_server(self, delay: float) -> tuple | bool | None: # NOTE: This does actually look like it's unreachable but I don't # want to alter the code any more than necessary for the first - # pass. See https://github.com/python-zk/kazoo/issues/772 + # pass. # The loop is basically a sleep with jitter that can be # Add some jitter between host pings while ( # type: ignore[unreachable] time.monotonic() < self.last_attempt + jitter ): return False - delay *= 2 + delay *= 2 # And while not unreachable, this is pointless return None @@ -169,7 +171,7 @@ def __init__( client: KazooClient, retry_sleeper: Any, logger: logging.Logger | None = None, - sasl_options: dict | None = None, + sasl_options: dict[str, str] | None = None, ): self.client = client self.handler = client.handler @@ -188,8 +190,10 @@ def __init__( self._socket: Socket | None = None self._xid: int | None = None - self._rw_server: tuple | None = None - self._ro_mode: Literal[False] | Iterator | None = False + self._rw_server: tuple[str, int] | None = None + self._ro_mode: Iterator[ + Literal[False] | tuple[str, int] | None + ] | Literal[False] | None = False self._connection_routine: Any | None = None @@ -250,7 +254,9 @@ def _server_pinger(self) -> RWPinger: self._socket_error_handling, ) - def _read_header(self, timeout: float | None) -> tuple: + def _read_header( + self, timeout: float | None + ) -> tuple[ReplyHeader, bytes, int]: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) @@ -276,7 +282,9 @@ def _read(self, length: int, timeout: float | None) -> bytes: ): pass else: - s = self.handler.select([self._socket], [], [], timeout)[0] + s = self.handler.select( + [cast("Socket", self._socket)], [], [], timeout + )[0] if not s: # pragma: nocover # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any @@ -376,9 +384,13 @@ def _write(self, msg: bytes, timeout: float | None) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) + # Note: The casts/type: ignore are because mypy can't work out + # self._socket is not None, and I don't want to change any code. with self._socket_error_handling(): while sent < msg_length: - s = self.handler.select([], [self._socket], [], timeout)[1] + s = self.handler.select( + [], [cast("Socket", self._socket)], [], timeout + )[1] if not s: # pragma: nocover # If the write list is empty, we got a timeout. We don't # have to check rlist and xlist as we don't set any @@ -455,7 +467,7 @@ def _read_response( # Determine if its an exists request and a no node error exists_error = ( - # NoNodeError does actually have a code. It's added by a wrapper, + # NoNodeError does actually have a code. It's added by a decorator, # which could possibly be better done via inheritance but this is # less invasive to the existing code. header.err == NoNodeError.code # type: ignore[attr-defined] @@ -615,16 +627,28 @@ def zk_loop(self) -> None: self.client._session_callback(KeeperState.CLOSED) self.logger.log(BLATHER, "Connection stopped") - def _expand_client_hosts(self) -> list: + def _expand_client_hosts(self) -> list[tuple[str, str, int]]: # Expand the entire list in advance so we can randomize it if needed - host_ports = [] + host_ports: list[tuple[str, str, int]] = [] for host, port in self.client.hosts: try: host = host.strip() for rhost in socket.getaddrinfo( host, port, 0, 0, socket.IPPROTO_TCP ): - host_ports.append((host, rhost[4][0], rhost[4][1])) + # FIXME These casts seem to be unnecessary on later + # versions of mypy/python + host_ports.append( + ( + host, + cast( # type: ignore[redundant-cast] + "str", rhost[4][0] + ), + cast( # type: ignore[redundant-cast] + "int", rhost[4][1] + ), + ) + ) except socket.gaierror as e: # Skip hosts that don't resolve self.logger.warning("Cannot resolve %s: %s", host, e) @@ -681,6 +705,9 @@ def _connect_attempt( try: self._xid = 0 read_timeout, connect_timeout = self._connect(host, hostip, port) + # I think the above implies self._socket can't be none, and + # self._read_sock is set up in start but mypy can't tell that. + # Hence the casting. read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -694,7 +721,13 @@ def _connect_attempt( # Ensure our timeout is positive timeout = max([deadline - time.monotonic(), jitter_time]) s = self.handler.select( - [self._socket, self._read_sock], [], [], timeout + [ + cast("Socket", self._socket), + cast("Socket", self._read_sock), + ], + [], + [], + timeout, )[0] if not s: @@ -704,14 +737,14 @@ def _connect_attempt( "outstanding heartbeat ping not received" ) else: - if self._socket in s: + if cast("Socket", self._socket) in s: response = self._read_socket(read_timeout) if response == CLOSE_RESPONSE: break # Check if any requests need sending before proceeding # to process more responses. Otherwise the responses # may choke out the requests. See PR#633. - if self._read_sock in s: + if cast("Socket", self._read_sock) in s: self._send_request(read_timeout, connect_timeout) # Requests act as implicit pings. last_send = time.monotonic() @@ -882,7 +915,7 @@ def _authenticate_with_sasl(self, host: str, timeout: float) -> None: # but again, I want to avoid code changes as much as possible. sasl_cli = ( self.client.sasl_cli # type: ignore[attr-defined] - ) = puresasl.client.SASLClient( + ) = puresasl.client.SASLClient( # type: ignore[no-untyped-call] host=host, **self.sasl_options, # type: ignore[arg-type] ) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 274292613..e1ffc792f 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -1,4 +1,10 @@ -"""Zookeeper Serializers, Deserializers, and NamedTuple objects""" +"""Zookeeper Serializers, Deserializers, and namedtuple objects + +Note: On python3.8, you can't do classvars with NamedTuple. + +FIXME As soon as we get off python3.8 we should change the namedtuple objects +to NamedTuple, as it should get better typechecking. +""" from __future__ import annotations import struct @@ -26,20 +32,24 @@ stat_struct = struct.Struct("!qqqqiiiqiiq") -def read_string(buffer: bytes, offset: int) -> tuple: +def read_string(buffer: bytes, offset: int) -> tuple[str, int]: """Reads an int specified buffer into a string and returns the string and the new offset in the buffer""" length = int_struct.unpack_from(buffer, offset)[0] offset += int_struct.size if length < 0: - return None, offset + # A note: write_str sends a length of -1 to indicate a value of None + # was passed. Not entirely sure where this happens because none of the + # callers of read_string seem to expect a None value. + # Should be ignoring return-value but hound cli... + return None, offset # type: ignore else: index = offset offset += length return buffer[index : index + length].decode("utf-8"), offset -def read_acl(bytes: bytes, offset: int) -> tuple: +def read_acl(bytes: bytes, offset: int) -> tuple[ACL, int]: perms = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size scheme, offset = read_string(bytes, offset) @@ -62,7 +72,7 @@ def write_buffer(bytes: bytes | None) -> bytes: return int_struct.pack(len(bytes)) + bytes -def read_buffer(bytes: bytes, offset: int) -> tuple: +def read_buffer(bytes: bytes, offset: int) -> tuple[bytes | None, int]: length = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if length < 0: @@ -356,7 +366,7 @@ def deserialize( if count == -1: # pragma: nocover return [] - children = [] + children: list[str] = [] for c in range(count): child, offset = read_string(bytes, offset) children.append(child) @@ -394,7 +404,7 @@ def serialize(self) -> bytearray: def deserialize(cls, bytes: bytes, offset: int) -> list[Any]: header = MultiHeader(None, False, None) results = [] - response = None + response: str | bool | ZnodeStat | BaseException | None = None while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 2058d2576..ad835266c 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum -from typing import Callable, NamedTuple +from typing import Any, Callable, NamedTuple # This is a (str, Enum) for backwards compatibility. @@ -162,8 +162,8 @@ class Callback(NamedTuple): """ type: str - func: Callable - args: tuple + func: Callable[..., Any] + args: tuple[Any, ...] class ZnodeStat(NamedTuple): diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 170cd884f..d5bd5e93f 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -18,7 +18,7 @@ import functools import logging import operator -from typing import Any, Callable, Generator, Protocol, TYPE_CHECKING +from typing import Any, Callable, Generator, Protocol, Tuple, TYPE_CHECKING from kazoo.exceptions import NoNodeError, KazooException @@ -397,7 +397,7 @@ def _process_result( self._publish_event(TreeEvent.INITIALIZED) -class TreeEvent(tuple): +class TreeEvent(Tuple[int, Any]): """The immutable event tuple of cache.""" NODE_ADDED = 0 @@ -432,7 +432,7 @@ def make(cls, event_type: int, event_data: Any) -> TreeEvent: return cls((event_type, event_data)) -class NodeData(tuple): +class NodeData(Tuple[str, bytes, Any]): """The immutable node data tuple of cache.""" #: The absolute path string of current node. diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index 0d8c40463..cb444ce4e 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -11,7 +11,7 @@ import datetime import json import socket -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING, cast from kazoo.exceptions import CancelledError @@ -121,7 +121,9 @@ def _encode(self, data_dict: dict[str, Any]) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) def _decode(self, raw: bytes) -> dict[str, Any]: - return json.loads(raw.decode(self._byte_encoding)) + return cast( + "dict[str, Any]", json.loads(raw.decode(self._byte_encoding)) + ) def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 71878b606..9cf433d0c 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -429,7 +429,7 @@ def _child_watching( asy.rawlink(func) return asy - def _establish_sessionwatch(self, state: Any) -> bool: + def _establish_sessionwatch(self, state: KazooState) -> bool: """Register ourself to listen for session events, we shut down if we become lost""" with self._state_change: diff --git a/pyproject.toml b/pyproject.toml index 5af3c0d4c..566bc9c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,11 @@ timeout = 180 ignore_missing_imports = false # Disallow dynamic typing -disallow_any_unimported = false # true +disallow_any_unimported = true disallow_any_expr = false disallow_any_decorated = false # true disallow_any_explicit = false # true -disallow_any_generics = false # true +disallow_any_generics = true disallow_subclassing_any = true # Untyped definitions and calls @@ -81,7 +81,7 @@ hide_error_codes = false pretty = true color_output = true error_summary = true -show_absolute_path = true +show_absolute_path = false # Miscellaneous warn_unused_configs = true @@ -90,15 +90,11 @@ verbosity = 0 # FIXME: As type annotations are introduced, please remove the appropriate # ignore_errors flag below. New modules should NOT be added here! -# no-any-return - We still have some imported modules with no type annotations, -# and I want to avoid code changes as much as possible. - # unused-ignore This is a temporary workaround for the fact that mypy can # produce different errors in 3.8 and 3.14, and I want to avoid code changes # as much as possible. disable_error_code = [ - 'no-any-return', 'unused-ignore', ] @@ -138,15 +134,12 @@ ignore_errors = true [[tool.mypy.overrides]] module = ["eventlet.*"] - ignore_missing_imports = true - #follow_untyped_imports = true + follow_untyped_imports = true [[tool.mypy.overrides]] module = ["gevent.thread"] - ignore_missing_imports = true - #follow_untyped_imports = true + follow_untyped_imports = true [[tool.mypy.overrides]] module = ["puresasl.*"] - ignore_missing_imports = true - #follow_untyped_imports = true + follow_untyped_imports = true From f027cf363e537f32348b7fcaba82b95d48576eee Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Sat, 2 May 2026 22:30:25 +0100 Subject: [PATCH 5/8] After some footling, decided on leaving disallow_any_decorated OFF --- .gitignore | 2 +- kazoo/client.py | 10 +++++-- kazoo/handlers/eventlet.py | 4 +-- kazoo/handlers/gevent.py | 8 ++---- kazoo/handlers/threading.py | 8 ++---- kazoo/handlers/utils.py | 54 ++++++++++++++++++++++++------------ kazoo/interfaces.py | 18 ++++++------ kazoo/protocol/connection.py | 7 +++-- kazoo/recipe/cache.py | 38 +++++++++++++++++-------- pyproject.toml | 8 +++--- 10 files changed, 96 insertions(+), 61 deletions(-) diff --git a/.gitignore b/.gitignore index 8df69daac..cc13a3453 100644 --- a/.gitignore +++ b/.gitignore @@ -29,7 +29,7 @@ zookeeper/ .idea .project .pydevproject -.tox +.tox*/ venv*/ /.settings /.metadata diff --git a/kazoo/client.py b/kazoo/client.py index 1bd5ae80b..f7b964966 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -424,8 +424,12 @@ def __init__( # also this should be called with a func that returns nothing as the # 1st argument. - def _retry(*args: Any, **kwargs: Any) -> Any: - return self._retry.copy()(*args, **kwargs) + def _retry( + func: Callable[..., KazooRetry.RETRY_RETURN], + *args: Any, + **kwargs: Any, + ) -> KazooRetry.RETRY_RETURN: + return self._retry.copy()(func, *args, **kwargs) # (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", # variable has type "KazooRetry") so basically self.retry needs to be @@ -500,7 +504,7 @@ def client_state(self) -> KeeperState: return self._state @property - def client_id(self) -> tuple[Any, Any] | None: + def client_id(self) -> tuple[int | None, bytes] | None: """Returns the client id for this Zookeeper session if connected. diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 3cefeb546..24861f570 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -20,7 +20,7 @@ from kazoo.handlers.utils import selector_select if TYPE_CHECKING: - from kazoo.interfaces import Event, Lockable, ReentrantLock, Socket + from kazoo.interfaces import Event, FdLike, Lockable, ReentrantLock, Socket from kazoo.protocol.states import Callback @@ -201,7 +201,7 @@ def create_connection(self, *args: Any, **kwargs: Any) -> Socket: def select( self, *args: Any, **kwargs: Any - ) -> tuple[list[int], list[int], list[int]]: + ) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: with _yield_before_after(): # Following appears to be a bug in mypy (see # https://github.com/python/mypy/issues/6799) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index 1576eef7f..ad806b579 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -24,7 +24,7 @@ from kazoo.handlers import utils if TYPE_CHECKING: - from kazoo.interfaces import HasFileNo, Lockable, Socket + from kazoo.interfaces import FdLike, Lockable, Socket from kazoo.protocol.states import Callback _using_libevent = gevent.__version__.startswith("0.") @@ -146,11 +146,7 @@ def stop(self) -> None: def select( self, *args: Any, **kwargs: Any - ) -> tuple[ - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - ]: + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: # FIXME use the correct arguments, not *args, *kwargs return selector_select( # Likely a bug in mypy (see diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index 7699487c3..377d8e534 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from kazoo.interfaces import ( Event, - HasFileNo, + FdLike, Lockable, ReentrantLock, Socket, @@ -187,11 +187,7 @@ def stop(self) -> None: def select( self, *args: Any, **kwargs: Any - ) -> tuple[ - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - ]: + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: return selector_select(*args, **kwargs) def socket(self) -> Socket: diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 46f017529..ab6669f54 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -11,9 +11,10 @@ import socket import time from types import ModuleType -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING -from kazoo.interfaces import IAsyncResult + +from kazoo.interfaces import IAsyncResult, FdLike if TYPE_CHECKING: from kazoo.interfaces import Socket @@ -323,9 +324,14 @@ def create_tcp_connection( return sock +CapturedResult = TypeVar("CapturedResult") + + def capture_exceptions( async_result: IAsyncResult, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +) -> Callable[ + [Callable[..., CapturedResult]], Callable[..., CapturedResult | None] +]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -333,13 +339,18 @@ def capture_exceptions( """ - def capture(function: Callable[..., Any]) -> Callable[..., Any]: + def capture( + function: Callable[..., CapturedResult] + ) -> Callable[..., CapturedResult | None]: @functools.wraps(function) - def captured_function(*args: Any, **kwargs: Any) -> Any: + def captured_function( + *args: Any, **kwargs: Any + ) -> CapturedResult | None: try: return function(*args, **kwargs) except Exception as exc: async_result.set_exception(exc) + return None return captured_function @@ -348,7 +359,9 @@ def captured_function(*args: Any, **kwargs: Any) -> Any: def wrap( async_result: IAsyncResult, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +) -> Callable[ + [Callable[..., CapturedResult]], Callable[..., CapturedResult | None] +]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a non-None return value. @@ -357,9 +370,13 @@ def wrap( """ - def capture(function: Callable[..., Any]) -> Callable[..., Any]: + def capture( + function: Callable[..., CapturedResult] + ) -> Callable[..., CapturedResult | None]: @capture_exceptions(async_result) - def captured_function(*args: Any, **kwargs: Any) -> Any: + def captured_function( + *args: Any, **kwargs: Any + ) -> CapturedResult | None: value = function(*args, **kwargs) if value is not None: async_result.set(value) @@ -370,7 +387,7 @@ def captured_function(*args: Any, **kwargs: Any) -> Any: return capture -def fileobj_to_fd(fileobj: Any) -> int: +def fileobj_to_fd(fileobj: FdLike) -> int: """Return a file descriptor from a file object. Parameters: @@ -385,22 +402,25 @@ def fileobj_to_fd(fileobj: Any) -> int: if isinstance(fileobj, int): fd = fileobj else: + # FIXME given the protocol I don't think the try/catch/int are + # required. try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise TypeError("Invalid file object: " "{!r}".format(fileobj)) + # FIXME Questionable, just let select deal with it. if fd < 0: raise TypeError("Invalid file descriptor: {}".format(fd)) return fd def selector_select( - rlist: list[Any], - wlist: list[Any], - xlist: list[Any], + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], timeout: float | None = None, selectors_module: ModuleType = selectors, -) -> tuple[list[int], list[int], list[int]]: +) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: """Selector-based drop-in replacement for select to overcome select limitation on a maximum filehandle value. """ @@ -415,7 +435,7 @@ def selector_select( selectors_module.EVENT_WRITE: wlist, } fd_events: defaultdict[int, int] = defaultdict(int) - fd_fileobjs: defaultdict[int, list[int]] = defaultdict(list) + fd_fileobjs: defaultdict[int, list[FdLike]] = defaultdict(list) for event, fileobjs in events_mapping.items(): for fileobj in fileobjs: @@ -431,9 +451,9 @@ def selector_select( # gevent can raise OSError raise ValueError("Invalid event mask or fd") from e - revents: list[int] = [] - wevents: list[int] = [] - xevents: list[int] = [] + revents: list[FdLike] = [] + wevents: list[FdLike] = [] + xevents: list[FdLike] = [] try: ready = selector.select(timeout) finally: diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 6751bf49e..e0f17ec9c 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -18,6 +18,7 @@ Callable, Iterable, Protocol, + Union, TYPE_CHECKING, ) @@ -28,12 +29,15 @@ class HasFileNo(Protocol): - """Protocol for things like select""" + """Protocol for objects that support a fileno method.""" def fileno(self) -> int: ... +FdLike = Union[int, HasFileNo] + + class Socket(HasFileNo, Protocol): """This is for things that provide a socket.socket-like interface. @@ -176,15 +180,11 @@ def stop(self) -> None: @abc.abstractmethod def select( self, - rlist: Iterable[int | HasFileNo], - wlist: Iterable[int | HasFileNo], - xlist: Iterable[int | HasFileNo], + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], timeout: float | None = None, - ) -> tuple[ - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - Iterable[int | HasFileNo], - ]: + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: """A select method that implements Python's select.select API""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index c2e0d5089..7c39737ff 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -11,7 +11,7 @@ import socket import ssl import time -from typing import Any, Iterator, Literal, TYPE_CHECKING, cast +from typing import Any, Iterator, Literal, TypeVar, TYPE_CHECKING, cast from kazoo.exceptions import ( AuthFailedError, @@ -163,6 +163,9 @@ class RWServerAvailable(Exception): """Thrown if a RW Server becomes available""" +ReturnValue = TypeVar("ReturnValue") + + class ConnectionHandler(object): """Zookeeper connection handler""" @@ -203,7 +206,7 @@ def __init__( # This is instance specific to avoid odd thread bug issues in Python # during shutdown global cleanup @contextmanager - def _socket_error_handling(self) -> Any: + def _socket_error_handling(self) -> Iterator[None]: try: yield except (socket.error, select.error) as e: diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index d5bd5e93f..b62515754 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -18,7 +18,15 @@ import functools import logging import operator -from typing import Any, Callable, Generator, Protocol, Tuple, TYPE_CHECKING +from typing import ( + Any, + Callable, + Generator, + Protocol, + TypeVar, + Tuple, + TYPE_CHECKING, +) from kazoo.exceptions import NoNodeError, KazooException @@ -33,6 +41,9 @@ logger = logging.getLogger(__name__) +ReturnValue = TypeVar("ReturnValue") + + class TreeCache(object): """The cache of a ZooKeeper subtree. @@ -52,8 +63,8 @@ def __init__(self, client: KazooClient, path: str): self._state = self.STATE_LATENT self._outstanding_ops = 0 self._is_initialized = False - self._error_listeners: list[Callable[[Exception], Any]] = [] - self._event_listeners: list[Callable[[TreeEvent], Any]] = [] + self._error_listeners: list[Callable[[Exception], None]] = [] + self._event_listeners: list[Callable[[TreeEvent], None]] = [] self._task_queue = client.handler.queue_impl() self._task_thread = None @@ -119,8 +130,8 @@ def close(self) -> None: self._root.on_deleted() def listen( - self, listener: Callable[[TreeEvent], Any] - ) -> Callable[[TreeEvent], Any]: + self, listener: Callable[[TreeEvent], None] + ) -> Callable[[TreeEvent], None]: """Registers a function to listen the cache events. The cache events are changes of local data. They are delivered from @@ -136,8 +147,8 @@ def listen( return listener def listen_fault( - self, listener: Callable[[Exception], Any] - ) -> Callable[[Exception], Any]: + self, listener: Callable[[Exception], None] + ) -> Callable[[Exception], None]: """Registers a function to listen the exceptions. It is possible to meet some exceptions during the cache running. You @@ -175,6 +186,9 @@ def get_children( does not exist. :raises ValueError: If the path is outside of this subtree. :returns: The :class:`frozenset` which including children names. + + # FIXME the default return value should be an empty frozenset, + # returning None is confusing. """ node = self._find_node(path) return default if node is None else frozenset(node._children) @@ -191,7 +205,9 @@ def _find_node(self, path: str) -> TreeNode | None: current_node = current_node._children[node_name] return current_node - def _publish_event(self, event_type: int, event_data: Any = None) -> None: + def _publish_event( + self, event_type: int, event_data: int | None = None + ) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: logger.debug("public event: %r", event) @@ -219,7 +235,7 @@ def _do_background(self) -> None: # release before possible idle del cb, func, args, kwargs - def _session_watcher(self, state: Any) -> None: + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.SUSPENDED: self._publish_event(TreeEvent.CONNECTION_SUSPENDED) elif state == KazooState.CONNECTED: @@ -415,7 +431,7 @@ class TreeEvent(Tuple[int, Any]): event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type: int, event_data: Any) -> TreeEvent: + def make(cls, event_type: int, event_data: int | None = None) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. @@ -455,7 +471,7 @@ def make(cls, path: str, data: bytes, stat: Any) -> NodeData: @contextlib.contextmanager def handle_exception( - listeners: list[Callable[[Exception], Any]], + listeners: list[Callable[[Exception], None]], ) -> Generator[None, None, None]: try: yield diff --git a/pyproject.toml b/pyproject.toml index 566bc9c6f..f7640949a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,9 @@ ignore_missing_imports = false # Disallow dynamic typing disallow_any_unimported = true disallow_any_expr = false -disallow_any_decorated = false # true +# disallow_any_decorated is disabled because it produces an error for almost +# every decorated function (at least in python3.8) +disallow_any_decorated = false disallow_any_explicit = false # true disallow_any_generics = true disallow_subclassing_any = true @@ -94,9 +96,7 @@ verbosity = 0 # produce different errors in 3.8 and 3.14, and I want to avoid code changes # as much as possible. -disable_error_code = [ - 'unused-ignore', -] +disable_error_code = [ 'unused-ignore' ] [[tool.mypy.overrides]] module = [ From 2096f53958f301034ff8b04e0169b37ead6b28ca Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Sun, 3 May 2026 20:47:23 +0100 Subject: [PATCH 6/8] Get rid of most of the Any uses --- kazoo/client.py | 36 +++++++++++++++++----- kazoo/exceptions.py | 4 +-- kazoo/handlers/threading.py | 2 +- kazoo/handlers/utils.py | 2 +- kazoo/interfaces.py | 26 ++++++++++++++-- kazoo/protocol/connection.py | 54 ++++++++++++++++++++++++--------- kazoo/protocol/serialization.py | 27 +++++++++++------ kazoo/recipe/barrier.py | 10 +++--- kazoo/recipe/cache.py | 18 ++++++----- kazoo/recipe/election.py | 4 ++- kazoo/recipe/lock.py | 31 ++++++++++--------- kazoo/recipe/partitioner.py | 6 ++-- kazoo/recipe/party.py | 6 ++-- kazoo/recipe/queue.py | 6 ++-- kazoo/recipe/watchers.py | 3 +- 15 files changed, 159 insertions(+), 76 deletions(-) diff --git a/kazoo/client.py b/kazoo/client.py index f7b964966..90c63a344 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -9,6 +9,7 @@ from os.path import split import re import warnings +from types import TracebackType from typing import ( cast, overload, @@ -19,6 +20,7 @@ Optional, Sequence, Set, + TypedDict, TYPE_CHECKING, ) @@ -79,7 +81,7 @@ from kazoo.recipe.watchers import ChildrenWatch, DataWatch if TYPE_CHECKING: - from kazoo.interfaces import IAsyncResult, IHandler + from kazoo.interfaces import Event, IAsyncResult, IHandler from kazoo.protocol.states import ZnodeStat @@ -114,6 +116,19 @@ WatchFunc = Callable[[WatchedEvent], Optional[bool]] +# Kazoo retry parameters +class KazooRetryParams(TypedDict, total=False): + max_tries: int + delay: float + backoff: int + max_jitter: float + max_delay: float + ignore_expire: bool + sleep_func: Callable[[float], None] + deadline: float + interrupt: Callable[[], bool] + + class KazooClient(object): """An Apache Zookeeper Python client supporting alternate callback handlers and high-level functionality. @@ -125,7 +140,7 @@ class KazooClient(object): """ - def __init__( + def __init__( # type: ignore[misc] self, hosts: str | list[str] = "127.0.0.1:2181", timeout: float = 10.0, @@ -136,9 +151,8 @@ def __init__( sasl_options: dict[str, str] | None = None, read_only: bool | None = None, randomize_hosts: bool = True, - # FIXME the dict should be a TypeDict - connection_retry: KazooRetry | dict[str, Any] | None = None, - command_retry: KazooRetry | dict[str, Any] | None = None, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, logger: logging.Logger | None = None, keyfile: str | None = None, keyfile_password: str | None = None, @@ -472,7 +486,7 @@ def _reset(self) -> None: self._reset_watchers() self._reset_session() self.last_zxid = 0 - self._protocol_version = None + self._protocol_version: int | None = None def _reset_watchers(self) -> None: watchers: list[WatchFunc] = [] @@ -754,7 +768,7 @@ def start(self, timeout: float = 15.0) -> None: "should be created before normal use." ) - def start_async(self) -> Any: + def start_async(self) -> Event: """Asynchronously initiate connection to ZK. :returns: An event object that can be checked to see if the @@ -1998,10 +2012,16 @@ def commit(self) -> list[Any]: def __enter__(self) -> TransactionRequest: return self - def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: """Commit and cleanup accumulated transaction data.""" if not exc_type: self.commit() + return None def _check_tx_state(self) -> None: if self.committed: diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index cb943b47f..ae9341df1 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from typing import Any, Callable, Type +from typing import Callable, Type class KazooException(Exception): @@ -54,7 +54,7 @@ class SASLException(KazooException): """ -def _invalid_error_code() -> Any: +def _invalid_error_code() -> Type[ZookeeperError]: raise RuntimeError("Invalid error code") diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index 377d8e534..829a70107 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -44,7 +44,7 @@ log = logging.getLogger(__name__) -def _to_fileno(obj: Any) -> int: +def _to_fileno(obj: FdLike) -> int: if isinstance(obj, int): fd = int(obj) elif hasattr(obj, "fileno"): diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index ab6669f54..44c1c8461 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -222,7 +222,7 @@ def create_tcp_socket(module: ModuleType) -> Socket: def create_tcp_connection( module: ModuleType, - address: Any, + address: tuple[str, str | int], hostname: str | None = None, timeout: float | None = None, use_ssl: bool = False, diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index e0f17ec9c..a9e4429a1 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -13,6 +13,7 @@ import abc import queue +from types import TracebackType from typing import ( Any, Callable, @@ -79,7 +80,12 @@ class Lockable(Protocol): def __enter__(self) -> None: ... - def __exit__(self, x: Any, y: Any, z: Any) -> None: + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: ... def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: @@ -102,7 +108,12 @@ class ReentrantLock(Protocol): def __enter__(self) -> None: ... - def __exit__(self, x: Any, y: Any, z: Any) -> None: + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: ... def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: @@ -128,6 +139,13 @@ def wait(self, timeout: float | None = None) -> bool: ... +class Threadlike(Protocol): + """Protocol for something like a thread.""" + + def join(self, timeout: float | None = None) -> None: + ... + + SpawnedFunc = Callable[..., None] @@ -223,7 +241,9 @@ def async_result(self) -> IAsyncResult: handler""" @abc.abstractmethod - def spawn(self, func: SpawnedFunc, *args: Any, **kwargs: Any) -> Any: + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> Threadlike: """Spawn a function to run asynchronously :param args: args to call the function with. diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 7c39737ff..94640cd2f 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -11,7 +11,17 @@ import socket import ssl import time -from typing import Any, Iterator, Literal, TypeVar, TYPE_CHECKING, cast +from typing import ( + Callable, + ContextManager, + Iterator, + Literal, + TypeVar, + TYPE_CHECKING, + cast, + overload, +) +from typing_extensions import Buffer from kazoo.exceptions import ( AuthFailedError, @@ -45,12 +55,13 @@ ) from kazoo.retry import ( ForceRetryError, + KazooRetry, RetryFailedError, ) if TYPE_CHECKING: from kazoo.client import KazooClient, WatchFunc - from kazoo.interfaces import Socket + from kazoo.interfaces import Socket, Threadlike try: import puresasl @@ -84,7 +95,7 @@ # removed from Python3+ -def buffer(obj: Any, offset: int = 0) -> memoryview: +def buffer(obj: Buffer, offset: int = 0) -> memoryview: return memoryview(obj)[offset:] @@ -105,8 +116,8 @@ class RWPinger(object): def __init__( self, hosts: list[tuple[str, int]], - connection_func: Any, - socket_handling: Any, + connection_func: Callable[..., Socket], + socket_handling: Callable[[], ContextManager[None]], ): self.hosts = hosts self.connection = connection_func @@ -172,7 +183,7 @@ class ConnectionHandler(object): def __init__( self, client: KazooClient, - retry_sleeper: Any, + retry_sleeper: KazooRetry, logger: logging.Logger | None = None, sasl_options: dict[str, str] | None = None, ): @@ -198,7 +209,7 @@ def __init__( Literal[False] | tuple[str, int] | None ] | Literal[False] | None = False - self._connection_routine: Any | None = None + self._connection_routine: Threadlike | None = None self.sasl_options = sasl_options self.sasl_cli = None @@ -312,12 +323,24 @@ def _read(self, length: int, timeout: float | None) -> bytes: remaining -= len(chunk) return b"".join(msgparts) + @overload + def _invoke( + self, timeout: float | None, request: Connect + ) -> tuple[Connect, int | None]: + ... + + @overload + def _invoke( + self, timeout: float | None, request: Auth, xid: int + ) -> int | None: + ... + def _invoke( self, timeout: float | None, - request: Any, + request: Auth | Connect, xid: int | None = None, - ) -> Any: + ) -> tuple[Connect, int | None] | int | None: """A special writer used during connection establishment only""" self._submit(request, timeout, xid) @@ -346,7 +369,9 @@ def _invoke( if hasattr(request, "deserialize"): try: - obj, _ = request.deserialize(msg, 0) + # This is a bit of an annoying ignore as I've just done a + # hasattr... + obj, _ = request.deserialize(msg, 0) # type:ignore[union-attr] except Exception: self.logger.exception( "Exception raised during deserialization " @@ -363,7 +388,7 @@ def _invoke( def _submit( self, - request: Any, + request: Auth | Connect | Ping | SASL, timeout: float | None, xid: int | None = None, ) -> None: @@ -451,7 +476,7 @@ def _read_watch_event(self, buffer: bytes, offset: int) -> None: def _read_response( self, - header: Any, + header: ReplyHeader, buffer: bytes, offset: int, ) -> object | None: @@ -660,7 +685,7 @@ def _expand_client_hosts(self) -> list[tuple[str, str, int]]: random.shuffle(host_ports) return host_ports - def _connect_loop(self, retry: Any) -> object: + def _connect_loop(self, retry: KazooRetry) -> object: # Iterate through the hosts a full cycle before starting over status = None host_ports = self._expand_client_hosts() @@ -687,7 +712,7 @@ def _connect_attempt( host: str, hostip: str, port: int, - retry: Any, + retry: KazooRetry, ) -> object: client = self.client KazooTimeoutError = self.handler.timeout_exception @@ -725,6 +750,7 @@ def _connect_attempt( timeout = max([deadline - time.monotonic(), jitter_time]) s = self.handler.select( [ + # FIXME we should know these aren't None cast("Socket", self._socket), cast("Socket", self._read_sock), ], diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index e1ffc792f..29a55f35d 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -9,9 +9,9 @@ import struct from collections import namedtuple -from typing import Any, ClassVar, Sequence, TYPE_CHECKING +from typing import ClassVar, Sequence, Union, TYPE_CHECKING -from kazoo.exceptions import EXCEPTIONS +from kazoo.exceptions import EXCEPTIONS, ZookeeperError from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id @@ -136,7 +136,7 @@ def serialize(self) -> bytearray: return b @classmethod - def deserialize(cls, bytes: bytes, offset: int) -> tuple[Any, int]: + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Connect, int]: proto_version, timeout, session_id = int_int_long_struct.unpack_from( bytes, offset ) @@ -387,8 +387,13 @@ def serialize(self) -> bytearray: return b +# FIXME Transaction class should move after Create2 +Transaction_Types = Union[Create, "Create2", Delete, SetData, CheckVersion] +Transaction_Response = Union[str, bool, ZnodeStat, ZookeeperError, None] + + class Transaction(namedtuple("Transaction", "operations")): - operations: list[Any] + operations: list[Transaction_Types] type: ClassVar[int] = 14 @@ -401,10 +406,12 @@ def serialize(self) -> bytearray: return b + multiheader_struct.pack(-1, True, -1) @classmethod - def deserialize(cls, bytes: bytes, offset: int) -> list[Any]: + def deserialize( + cls, bytes: bytes, offset: int + ) -> list[Transaction_Response]: header = MultiHeader(None, False, None) - results = [] - response: str | bool | ZnodeStat | BaseException | None = None + results: list[Transaction_Response] = [] + response: Transaction_Response = None while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) @@ -425,8 +432,10 @@ def deserialize(cls, bytes: bytes, offset: int) -> list[Any]: return results @staticmethod - def unchroot(client: KazooClient, response: list[Any]) -> list[Any]: - resp = [] + def unchroot( + client: KazooClient, response: list[Transaction_Response] + ) -> list[Transaction_Response]: + resp: list[Transaction_Response] = [] for result in response: if isinstance(result, str): resp.append(client.unchroot(result)) diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index efe6cd053..26ffc9357 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -10,10 +10,10 @@ import os import socket import uuid -from typing import Any, Literal, TYPE_CHECKING +from typing import Literal, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent if TYPE_CHECKING: from kazoo.client import KazooClient @@ -71,7 +71,7 @@ def wait(self, timeout: float | None = None) -> bool: """ cleared = self.client.handler.event_object() - def wait_for_clear(event: Any) -> None: + def wait_for_clear(event: WatchedEvent) -> None: if event.type == EventType.DELETED: cleared.set() @@ -158,7 +158,7 @@ def _inner_enter(self) -> Literal[True]: except NodeExistsError: pass - def created(event: Any) -> None: + def created(event: WatchedEvent) -> None: if event.type == EventType.CREATED: ready.set() @@ -201,7 +201,7 @@ def _inner_leave(self) -> bool: ready = self.client.handler.event_object() - def deleted(event: Any) -> None: + def deleted(event: WatchedEvent) -> None: if event.type == EventType.DELETED: ready.set() diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index b62515754..085af340a 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -26,16 +26,17 @@ TypeVar, Tuple, TYPE_CHECKING, + Union, ) from kazoo.exceptions import NoNodeError, KazooException from kazoo.protocol.paths import _prefix_root, join as kazoo_join -from kazoo.protocol.states import KazooState, EventType +from kazoo.protocol.states import KazooState, EventType, ZnodeStat if TYPE_CHECKING: from kazoo.client import KazooClient, WatchFunc - from kazoo.interfaces import IAsyncResult + from kazoo.interfaces import IAsyncResult, Threadlike from kazoo.protocol.states import WatchedEvent logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def __init__(self, client: KazooClient, path: str): self._error_listeners: list[Callable[[Exception], None]] = [] self._event_listeners: list[Callable[[TreeEvent], None]] = [] self._task_queue = client.handler.queue_impl() - self._task_thread = None + self._task_thread: Threadlike | None = None def start(self) -> None: """Starts the cache. @@ -366,7 +367,7 @@ def _process_watch(self, watched_event: WatchedEvent) -> None: self._refresh_children() def _process_result( - self, method_name: str, path: str, result: Any + self, method_name: str, path: str, result: IAsyncResult ) -> None: logger.debug("process_result: %s %s", method_name, path) if method_name == "exists": @@ -413,9 +414,12 @@ def _process_result( self._publish_event(TreeEvent.INITIALIZED) -class TreeEvent(Tuple[int, Any]): +# mypy doesn't like inheriting from tuples, though TBH this would look a lot +# better using NamedTuple though the event_type needs sorting. +class TreeEvent(Tuple[int, Union[int, None]]): # type: ignore[misc] """The immutable event tuple of cache.""" + # FIXME These should be an enum. NODE_ADDED = 0 NODE_UPDATED = 1 NODE_REMOVED = 2 @@ -448,7 +452,7 @@ def make(cls, event_type: int, event_data: int | None = None) -> TreeEvent: return cls((event_type, event_data)) -class NodeData(Tuple[str, bytes, Any]): +class NodeData(Tuple[str, bytes, ZnodeStat]): """The immutable node data tuple of cache.""" #: The absolute path string of current node. @@ -461,7 +465,7 @@ class NodeData(Tuple[str, bytes, Any]): stat = property(operator.itemgetter(2)) @classmethod - def make(cls, path: str, data: bytes, stat: Any) -> NodeData: + def make(cls, path: str, data: bytes, stat: ZnodeStat) -> NodeData: """Creates a new NodeData tuple. :returns: A :class:`~kazoo.recipe.cache.NodeData` instance. diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 09fe43349..3fce5fb2d 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -47,7 +47,9 @@ def __init__( """ self.lock = client.Lock(path, identifier) - def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def run( + self, func: Callable[..., None], *args: Any, **kwargs: Any + ) -> None: """Contend for the leadership This call will block until either this contender is cancelled diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 29ce5cd91..6f3236db2 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -21,12 +21,12 @@ import time import uuid from typing import ( - Any, Iterable, Literal, Pattern, TYPE_CHECKING, ) +from types import TracebackType from kazoo.exceptions import ( CancelledError, @@ -34,7 +34,7 @@ LockTimeout, NoNodeError, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent from kazoo.retry import ( ForceRetryError, KazooRetry, @@ -235,7 +235,7 @@ def acquire( finally: self._acquire_method_lock.release() - def _watch_session(self, state: Any) -> bool: + def _watch_session(self, state: KazooState) -> bool: self.wake_event.set() return True @@ -301,7 +301,7 @@ def _inner_acquire( finally: self.client.remove_listener(self._watch_session) - def _watch_predecessor(self, event: Any) -> None: + def _watch_predecessor(self, event: WatchedEvent) -> None: self.wake_event.set() def _get_predecessor(self, node: str) -> str | None: @@ -433,11 +433,12 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Any, - exc_value: Any, - traceback: Any, - ) -> None: + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: self.release() + return None class WriteLock(Lock): @@ -664,7 +665,7 @@ def _inner_acquire( w = _Watch(duration=timeout) w.start() - # This is passing bytes data, but self.client.Lock expects a str, + # FIXME This is passing bytes data, but self.client.Lock expects a str, # which I think is a bug in this code. However, I don't want to # change any code at this point, so we just ignore the type error here. lock = self.client.Lock( @@ -695,10 +696,10 @@ def _inner_acquire( finally: lock.release() - def _watch_lease_change(self, event: Any) -> None: + def _watch_lease_change(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_lease(self, data: Any = None) -> bool: + def _get_lease(self) -> bool: # Make sure the session is still valid if self._session_expired: raise ForceRetryError("Retry on session loss at top") @@ -727,7 +728,7 @@ def _get_lease(self, data: Any = None) -> bool: # Return current state return self.is_acquired - def _watch_session(self, state: Any) -> bool | None: + def _watch_session(self, state: KazooState) -> bool | None: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() @@ -784,8 +785,8 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Any, - exc_value: Any, - traceback: Any, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.release() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 9cf433d0c..eaa814876 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from kazoo.client import KazooClient - from kazoo.interfaces import IAsyncResult + from kazoo.interfaces import Event, IAsyncResult from kazoo.recipe.lock import Lock log = logging.getLogger(__name__) @@ -157,7 +157,7 @@ def __init__( identifier: str | None = None, time_boundary: float = 30, max_reaction_time: float = 1, - state_change_event: Any | None = None, + state_change_event: Event | None = None, ): """Create a :class:`~SetPartitioner` instance @@ -288,7 +288,7 @@ def _fail_out(self) -> None: except KazooException: # pragma: nocover pass - def _allocate_transition(self, result: Any) -> None: + def _allocate_transition(self, result: IAsyncResult) -> None: """Called when in allocating mode, and the children settled""" # Did we get an exception waiting for children to settle? diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index f28f2be11..1fc1340b7 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -11,7 +11,7 @@ from __future__ import annotations import uuid -from typing import Any, Iterator, TYPE_CHECKING +from typing import Iterator, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError @@ -47,7 +47,7 @@ def _ensure_parent(self) -> None: self.client.ensure_path(self.path) self.ensured_path = True - def join(self) -> Any: + def join(self) -> None: """Join the party""" return self.client.retry(self._inner_join) @@ -69,7 +69,7 @@ def _inner_join(self) -> None: # suspended connection self.participating = True - def leave(self) -> Any: + def leave(self) -> bool: """Leave the party""" self.participating = False return self.client.retry(self._inner_leave) diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index dae3ed5fd..85a866764 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -13,10 +13,10 @@ from __future__ import annotations import uuid -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent from kazoo.retry import ForceRetryError if TYPE_CHECKING: @@ -302,7 +302,7 @@ def _inner_get(self, timeout: float | None) -> bytes | None: canceled = False value = [] - def check_for_updates(event: Any | None) -> None: + def check_for_updates(event: WatchedEvent | None) -> None: if event is not None and event.type != EventType.CHILD: return with lock: diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index 32f38042f..983ee869d 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -193,7 +193,8 @@ def _log_func_exception( # The type ignores here are because mypy can't figure out that # 1) self._func can't ever be None (fingers crossed) # 2) the function can be called with 2 arguments or with 3 - # arguments + # arguments (though that could possibly be done with better + # typing) result = self._func( # type: ignore[call-arg, misc] data, stat, event ) From b40a0da9d62def86802661a77e517f86da0fac15 Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Sun, 3 May 2026 21:51:05 +0100 Subject: [PATCH 7/8] It worked for me at home... --- .coveragerc | 2 +- setup.cfg | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 0c53d00b0..c79b53748 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,7 +7,7 @@ omit = # Note - this is a copy of the default exclusions from coverage 7.10.1 [report] -exclude_lines = +exclude_lines = #\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER) ^\s*(((async )?def .*?)?\)(\s*->.*?)?:\s*)?\.\.\.\s*(#|$) if (typing\.)?TYPE_CHECKING: diff --git a/setup.cfg b/setup.cfg index 89e8e8442..4dfd58163 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,6 +65,7 @@ test = eventlet>=0.17.1 ; implementation_name!='pypy' pyjks pyopenssl + typing-extensions eventlet = eventlet>=0.17.1 @@ -83,11 +84,15 @@ typing = mypy>=0.991 types-gevent +other = + typing-extensions + alldeps = %(dev)s %(eventlet)s %(gevent)s %(sasl)s %(docs)s + %(other)s %(typing)s From 2530d478ec063b8e7cb9ad8923160815d6d990a0 Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Mon, 4 May 2026 20:13:22 +0100 Subject: [PATCH 8/8] Ripped out a whole bunch of Any --- kazoo/client.py | 11 +++++---- kazoo/handlers/eventlet.py | 18 +++++++++++--- kazoo/handlers/gevent.py | 2 +- kazoo/handlers/utils.py | 21 +++++++++------- kazoo/interfaces.py | 6 ++++- kazoo/protocol/connection.py | 2 +- kazoo/recipe/cache.py | 39 +++++++++++++++++++++++------ kazoo/recipe/election.py | 10 ++++++-- kazoo/recipe/lease.py | 18 ++++++++------ kazoo/recipe/partitioner.py | 14 ++++++++--- kazoo/recipe/watchers.py | 48 ++++++++++++++++++++++++++++++------ kazoo/retry.py | 5 +++- pyproject.toml | 3 +-- 13 files changed, 146 insertions(+), 51 deletions(-) diff --git a/kazoo/client.py b/kazoo/client.py index 90c63a344..b7a881eaa 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -23,6 +23,7 @@ TypedDict, TYPE_CHECKING, ) +from typing_extensions import ParamSpec from kazoo.exceptions import ( AuthFailedError, @@ -115,6 +116,8 @@ # Signatures for get, get_children and exists watches WatchFunc = Callable[[WatchedEvent], Optional[bool]] +GenericArgs = ParamSpec("GenericArgs") + # Kazoo retry parameters class KazooRetryParams(TypedDict, total=False): @@ -436,12 +439,10 @@ def __init__( # type: ignore[misc] # to avoid shared retry counts self._retry = self.retry - # also this should be called with a func that returns nothing as the - # 1st argument. def _retry( - func: Callable[..., KazooRetry.RETRY_RETURN], - *args: Any, - **kwargs: Any, + func: Callable[GenericArgs, KazooRetry.RETRY_RETURN], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, ) -> KazooRetry.RETRY_RETURN: return self._retry.copy()(func, *args, **kwargs) diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 24861f570..ee8ed47d6 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -20,7 +20,14 @@ from kazoo.handlers.utils import selector_select if TYPE_CHECKING: - from kazoo.interfaces import Event, FdLike, Lockable, ReentrantLock, Socket + from kazoo.interfaces import ( + Event, + FdLike, + IHandler, + Lockable, + ReentrantLock, + Socket, + ) from kazoo.protocol.states import Callback @@ -51,7 +58,7 @@ class TimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler: Any): + def __init__(self, handler: IHandler): super(AsyncResult, self).__init__( handler, green_threading.Condition, # type: ignore[attr-defined] @@ -59,6 +66,7 @@ def __init__(self, handler: Any): ) +# FIXME This should inherit from IHandler class SequentialEventletHandler(object): """Eventlet handler for sequentially executing callbacks. @@ -174,7 +182,7 @@ def stop(self) -> None: self._started = False atexit.unregister(self.stop) - def socket(self, *args: Any, **kwargs: Any) -> Socket: + def socket(self) -> Socket: return utils.create_tcp_socket(green_socket) def create_socket_pair(self) -> tuple[Socket, Socket]: @@ -196,9 +204,11 @@ def rlock_object(self) -> ReentrantLock: green_threading.RLock(), # type: ignore[attr-defined] ) + # FIXME fix parameters def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) + # FIXME fix parameters def select( self, *args: Any, **kwargs: Any ) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: @@ -212,7 +222,7 @@ def select( ) def async_result(self) -> AsyncResult: - return AsyncResult(self) + return AsyncResult(self) # type: ignore[arg-type] def spawn( self, func: Any, *args: Any, **kwargs: Any diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index ad806b579..7a3c215cf 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -156,7 +156,7 @@ def select( **kwargs, # type: ignore[misc] ) - def socket(self, *args: Any, **kwargs: Any) -> Socket: + def socket(self) -> Socket: # See above return utils.create_tcp_socket(socket) diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 44c1c8461..019cc2d22 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -12,7 +12,7 @@ import time from types import ModuleType from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING - +from typing_extensions import ParamSpec from kazoo.interfaces import IAsyncResult, FdLike @@ -325,12 +325,14 @@ def create_tcp_connection( CapturedResult = TypeVar("CapturedResult") +GenericArgs = ParamSpec("GenericArgs") def capture_exceptions( async_result: IAsyncResult, ) -> Callable[ - [Callable[..., CapturedResult]], Callable[..., CapturedResult | None] + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], ]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -340,11 +342,11 @@ def capture_exceptions( """ def capture( - function: Callable[..., CapturedResult] - ) -> Callable[..., CapturedResult | None]: + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @functools.wraps(function) def captured_function( - *args: Any, **kwargs: Any + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs ) -> CapturedResult | None: try: return function(*args, **kwargs) @@ -360,7 +362,8 @@ def captured_function( def wrap( async_result: IAsyncResult, ) -> Callable[ - [Callable[..., CapturedResult]], Callable[..., CapturedResult | None] + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], ]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a @@ -371,11 +374,11 @@ def wrap( """ def capture( - function: Callable[..., CapturedResult] - ) -> Callable[..., CapturedResult | None]: + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @capture_exceptions(async_result) def captured_function( - *args: Any, **kwargs: Any + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs ) -> CapturedResult | None: value = function(*args, **kwargs) if value is not None: diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index a9e4429a1..71de8c21b 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -69,6 +69,9 @@ def setblocking(self, flags: bool) -> None: def setsockopt(self, level: int, optname: int, value: int) -> None: ... + def getpeername(self) -> tuple[str, int]: + ... + class Lockable(Protocol): """This is what threading.Lock implements. @@ -210,6 +213,7 @@ def select( def socket(self) -> Socket: """A socket method that implements Python's socket.socket API""" + # FIXME This should have a proper set of parameters. @abc.abstractmethod def create_connection(self, *args: Any, **kwargs: Any) -> Socket: """A socket method that implements Python's socket.create_connection @@ -367,7 +371,7 @@ def rawlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """ @abc.abstractmethod - def unlink(self, callback: Callable[[IAsyncResult], Any]) -> None: + def unlink(self, callback: Callable[[IAsyncResult], None]) -> None: """Remove the callback set by :meth:`rawlink` :param callback: A callback function to remove. diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 94640cd2f..e4ed1a17d 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -116,7 +116,7 @@ class RWPinger(object): def __init__( self, hosts: list[tuple[str, int]], - connection_func: Callable[..., Socket], + connection_func: Callable[[tuple[str, int]], Socket], socket_handling: Callable[[], ContextManager[None]], ): self.hosts = hosts diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 085af340a..cee96c82f 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -27,6 +27,7 @@ Tuple, TYPE_CHECKING, Union, + overload, ) @@ -207,7 +208,7 @@ def _find_node(self, path: str) -> TreeNode | None: return current_node def _publish_event( - self, event_type: int, event_data: int | None = None + self, event_type: int, event_data: NodeData | None = None ) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: @@ -219,7 +220,27 @@ def _do_publish_event(self, event: TreeEvent) -> None: with handle_exception(self._error_listeners): listener(event) + @overload def _in_background( + self, func: Callable[[TreeEvent], None], event: TreeEvent + ) -> None: + ... + + @overload + def _in_background(self, func: Callable[[], None]) -> None: + ... + + @overload + def _in_background( + self, + func: Callable[[str, str, IAsyncResult], None], + method_name: str, + path: str, + result: IAsyncResult, + ) -> None: + ... + + def _in_background( # type: ignore[misc] self, func: Callable[..., Any], *args: Any, **kwargs: Any ) -> None: self._task_queue.put((func, args, kwargs)) @@ -319,8 +340,10 @@ def on_deleted(self) -> None: del self._parent._children[child] self._reset_watchers() - def _publish_event(self, *args: Any, **kwargs: Any) -> Any: - return self._tree._publish_event(*args, **kwargs) + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: + return self._tree._publish_event(event_type, event_data) def _reset_watchers(self) -> None: client = self._tree._client @@ -414,9 +437,9 @@ def _process_result( self._publish_event(TreeEvent.INITIALIZED) -# mypy doesn't like inheriting from tuples, though TBH this would look a lot -# better using NamedTuple though the event_type needs sorting. -class TreeEvent(Tuple[int, Union[int, None]]): # type: ignore[misc] +# FIXME these Tuple-based classes would look a lot better using NamedTuple +# though the event_type in TreeEvent needs sorting. +class TreeEvent(Tuple[int, Union["NodeData", None]]): """The immutable event tuple of cache.""" # FIXME These should be an enum. @@ -435,7 +458,9 @@ class TreeEvent(Tuple[int, Union[int, None]]): # type: ignore[misc] event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type: int, event_data: int | None = None) -> TreeEvent: + def make( + cls, event_type: int, event_data: NodeData | None = None + ) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 3fce5fb2d..1e28517b7 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -7,13 +7,16 @@ from __future__ import annotations -from typing import Any, Callable, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING +from typing_extensions import ParamSpec from kazoo.exceptions import CancelledError if TYPE_CHECKING: from kazoo.client import KazooClient +GenericArgs = ParamSpec("GenericArgs") + class Election(object): """Kazoo Basic Leader Election @@ -48,7 +51,10 @@ def __init__( self.lock = client.Lock(path, identifier) def run( - self, func: Callable[..., None], *args: Any, **kwargs: Any + self, + func: Callable[GenericArgs, None], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, ) -> None: """Contend for the leadership diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index cb444ce4e..1e4c9cf85 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -11,7 +11,7 @@ import datetime import json import socket -from typing import Any, Callable, TYPE_CHECKING, cast +from typing import Callable, TypedDict, TYPE_CHECKING, cast from kazoo.exceptions import CancelledError @@ -19,6 +19,12 @@ from kazoo.client import KazooClient +class Lease(TypedDict): + version: int + holder: str + end: str + + class NonBlockingLease(object): """Exclusive lease that does not block. @@ -106,7 +112,7 @@ def _attempt_obtaining( return client.delete(holder_path) end_lease = (now + duration).strftime(self._date_format) - new_data = { + new_data: Lease = { "version": self._version, "holder": ident, "end": end_lease, @@ -117,13 +123,11 @@ def _attempt_obtaining( except CancelledError: pass - def _encode(self, data_dict: dict[str, Any]) -> bytes: + def _encode(self, data_dict: Lease) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) - def _decode(self, raw: bytes) -> dict[str, Any]: - return cast( - "dict[str, Any]", json.loads(raw.decode(self._byte_encoding)) - ) + def _decode(self, raw: bytes) -> Lease: + return cast("Lease", json.loads(raw.decode(self._byte_encoding))) def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index eaa814876..19c48c28c 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -25,7 +25,7 @@ import os import socket from enum import Enum -from typing import Any, Callable, Iterator, Sequence, TYPE_CHECKING +from typing import Callable, Iterator, Sequence, TYPE_CHECKING from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState @@ -402,11 +402,13 @@ def _abort_lock_acquisition(self) -> None: self._child_watching(self._allocate_transition, client_handler=True) + # FIXME This is only ever called with func=self._allocation_transition, but + # I didn't want to change the code. def _child_watching( self, - func: Callable[..., Any] | None = None, + func: Callable[[IAsyncResult], None] | None = None, client_handler: bool = False, - ) -> Any: + ) -> IAsyncResult: """Called when children are being watched to stabilize This actually returns immediately, child watcher spins up a @@ -425,7 +427,11 @@ def _child_watching( # to ensure that the rawlink's it might use won't be # blocked if client_handler: - func = partial(self._client.handler.spawn, func) + # FIXME This feels wrong, but it may be because partial is + # confusing things. + func = partial( # type: ignore[assignment] + self._client.handler.spawn, func + ) asy.rawlink(func) return asy diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index 983ee869d..aa82e5478 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -17,7 +17,16 @@ import logging import time import warnings -from typing import Any, List, Callable, Optional, Union, TYPE_CHECKING +from typing import ( + Any, + List, + Callable, + Optional, + Union, + TYPE_CHECKING, + overload, +) +from typing_extensions import ParamSpec from kazoo.exceptions import ConnectionClosedError, NoNodeError, KazooException from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat @@ -32,10 +41,14 @@ _STOP_WATCHING = object() +GenericArgs = ParamSpec("GenericArgs") -def _ignore_closed(func: Callable[..., None]) -> Callable[..., None]: + +def _ignore_closed( + func: Callable[GenericArgs, None] +) -> Callable[GenericArgs, None]: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> None: + def wrapper(*args: GenericArgs.args, **kwargs: GenericArgs.kwargs) -> None: try: return func(*args, **kwargs) except ConnectionClosedError: @@ -45,9 +58,9 @@ def wrapper(*args: Any, **kwargs: Any) -> None: DataWatchFunc = Union[ - Callable[[Optional[str], Optional[ZnodeStat]], Optional[bool]], + Callable[[Optional[bytes], Optional[ZnodeStat]], Optional[bool]], Callable[ - [Optional[str], Optional[ZnodeStat], Optional[WatchedEvent]], + [Optional[bytes], Optional[ZnodeStat], Optional[WatchedEvent]], Optional[bool], ], ] @@ -104,11 +117,32 @@ def my_func(data, stat, event): """ + @overload def __init__( self, client: KazooClient, path: str, func: DataWatchFunc | None = None, + ): + ... + + # FIXME This would get a @warnings.deprecated in py13+ + @overload + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): + ... + + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, *args: Any, **kwargs: Any, ): @@ -180,7 +214,7 @@ def __call__(self, func: DataWatchFunc) -> DataWatchFunc: def _log_func_exception( self, - data: Any, + data: bytes | None, stat: ZnodeStat | None, event: WatchedEvent | None = None, ) -> None: @@ -246,7 +280,7 @@ def _get_data(self, event: WatchedEvent | None = None) -> None: if initial_version != self._version or not self._ever_called: self._log_func_exception(data, stat, event) - def _watcher(self, event: KazooState) -> None: + def _watcher(self, event: WatchedEvent) -> None: self._get_data(event=event) def _set_watch(self, state: KazooState) -> None: diff --git a/kazoo/retry.py b/kazoo/retry.py index d406696ff..9e4e0c9a8 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -116,7 +116,10 @@ def copy(self) -> KazooRetry: return obj def __call__( - self, func: Callable[..., RETRY_RETURN], *args: Any, **kwargs: Any + self, + func: Callable[..., RETRY_RETURN], + *args: Any, + **kwargs: Any, ) -> RETRY_RETURN: """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/pyproject.toml b/pyproject.toml index f7640949a..48f3b5fae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,10 +92,9 @@ verbosity = 0 # FIXME: As type annotations are introduced, please remove the appropriate # ignore_errors flag below. New modules should NOT be added here! -# unused-ignore This is a temporary workaround for the fact that mypy can +# This is a temporary workaround for the fact that mypy can # produce different errors in 3.8 and 3.14, and I want to avoid code changes # as much as possible. - disable_error_code = [ 'unused-ignore' ] [[tool.mypy.overrides]]