Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions testtools/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import types
import unittest
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING, TypeVar, cast, overload
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast, overload
from unittest.case import SkipTest

T = TypeVar("T")
Expand Down Expand Up @@ -88,6 +88,8 @@ class _ExpectedFailure(Exception):


# TypeVar for decorators
_P = ParamSpec("_P")
_R = TypeVar("_R")
_F = TypeVar("_F", bound=Callable[..., object])


Expand Down Expand Up @@ -390,10 +392,10 @@ def _formatTypes(

def addCleanup(
self,
function: Callable[..., object],
function: Callable[_P, _R],
/,
*arguments: object,
**keywordArguments: object,
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
"""Add a cleanup function to be called after tearDown.

Expand All @@ -407,7 +409,7 @@ def addCleanup(
Cleanup functions are always called before a test finishes running,
even if setUp is aborted by an exception.
"""
self._cleanups.append((function, arguments, keywordArguments))
self._cleanups.append((function, args, kwargs))

def addOnException(self, handler: "Callable[[ExcInfo], None]") -> None:
"""Add a handler to be called when an exception occurs in test code.
Expand Down Expand Up @@ -503,26 +505,24 @@ def assertIsInstance( # type: ignore[override]
def assertRaises(
self,
expected_exception: type[BaseException] | tuple[type[BaseException]],
callable: Callable[..., object],
*args: object,
**kwargs: object,
callable: Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> BaseException: ...

@overload # type: ignore[override]
def assertRaises(
self,
expected_exception: type[BaseException] | tuple[type[BaseException]],
callable: None = ...,
*args: object,
**kwargs: object,
) -> "_AssertRaisesContext": ...

def assertRaises( # type: ignore[override]
self,
expected_exception: type[BaseException] | tuple[type[BaseException]],
callable: Callable[..., object] | None = None,
*args: object,
**kwargs: object,
callable: Callable[_P, _R] | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> "_AssertRaisesContext | BaseException":
"""Fail unless an exception of class expected_exception is thrown
by callable when invoked with arguments args and keyword
Expand Down Expand Up @@ -678,9 +678,9 @@ def defaultTestResult(self) -> TestResult:
def expectFailure(
self,
reason: str,
predicate: Callable[..., object],
*args: object,
**kwargs: object,
predicate: Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
"""Check that a test fails in a particular way.

Expand Down Expand Up @@ -1349,7 +1349,10 @@ class Nullary:
"""

def __init__(
self, callable_object: Callable[..., object], *args: object, **kwargs: object
self,
callable_object: Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
self._callable_object = callable_object
self._args = args
Expand Down