Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ notice-rgx = "Copyright \\(c\\) Microsoft Corporation\\.\\s*\\n.*Licensed under
# Temporary ignores for pyrit/ subdirectories until issue #1176
# https://github.com/Azure/PyRIT/issues/1176 is fully resolved
# TODO: Remove these ignores once the issues are fixed
"pyrit/{auxiliary_attacks,exceptions,models,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
"pyrit/{auxiliary_attacks,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
# Backend API routes raise HTTPException handled by FastAPI, not true exceptions
"pyrit/backend/**/*.py" = ["DOC501"]
"pyrit/__init__.py" = ["D104"]
Expand Down
2 changes: 2 additions & 0 deletions pyrit/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Exception classes, retry helpers, and execution context utilities."""

from pyrit.exceptions.exception_classes import (
BadRequestException,
EmptyResponseException,
Expand Down
124 changes: 114 additions & 10 deletions pyrit/exceptions/exception_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,46 @@


def _get_custom_result_retry_max_num_attempts() -> int:
"""Get the maximum number of retry attempts for custom result retry decorator."""
"""
Get the maximum number of retry attempts for custom result retry decorator.

Returns:
int: Maximum retry attempts.

"""
return int(os.getenv("CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS", 10))


def get_retry_max_num_attempts() -> int:
"""Get the maximum number of retry attempts."""
"""
Get the maximum number of retry attempts.

Returns:
int: Maximum retry attempts.

"""
return int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 10))


def _get_retry_wait_min_seconds() -> int:
"""Get the minimum wait time in seconds between retries."""
"""
Get the minimum wait time in seconds between retries.

Returns:
int: Minimum wait duration in seconds.

"""
return int(os.getenv("RETRY_WAIT_MIN_SECONDS", 5))


def _get_retry_wait_max_seconds() -> int:
"""Get the maximum wait time in seconds between retries."""
"""
Get the maximum wait time in seconds between retries.

Returns:
int: Maximum wait duration in seconds.

"""
return int(os.getenv("RETRY_WAIT_MAX_SECONDS", 220))


Expand Down Expand Up @@ -89,14 +113,28 @@ def __call__(self, retry_state: RetryCallState) -> float:


class PyritException(Exception, ABC):
"""Base exception class for PyRIT components."""

def __init__(self, *, status_code: int = 500, message: str = "An error occurred") -> None:
"""
Initialize a PyritException.

Args:
status_code (int): HTTP-style status code associated with the error.
message (str): Human-readable error description.

"""
self.status_code = status_code
self.message = message
super().__init__(f"Status Code: {status_code}, Message: {message}")

def process_exception(self) -> str:
"""
Logs and returns a string representation of the exception.
Log and return a JSON string representation of the exception.

Returns:
str: Serialized status code and message.

"""
log_message = f"{self.__class__.__name__} encountered: Status Code: {self.status_code}, Message: {self.message}"
logger.error(log_message)
Expand All @@ -108,20 +146,45 @@ class BadRequestException(PyritException):
"""Exception class for bad client requests."""

def __init__(self, *, status_code: int = 400, message: str = "Bad Request") -> None:
"""
Initialize a bad request exception.

Args:
status_code (int): Status code for the error.
message (str): Error message.

"""
super().__init__(status_code=status_code, message=message)


class RateLimitException(PyritException):
"""Exception class for authentication errors."""

def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Exception") -> None:
"""
Initialize a rate limit exception.

Args:
status_code (int): Status code for the error.
message (str): Error message.

"""
super().__init__(status_code=status_code, message=message)


class ServerErrorException(PyritException):
"""Exception class for opaque 5xx errors returned by the server."""

def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None:
"""
Initialize a server error exception.

Args:
status_code (int): Status code for the error.
message (str): Error message.
body (Optional[str]): Optional raw server response body.

"""
super().__init__(status_code=status_code, message=message)
self.body = body

Expand All @@ -130,28 +193,50 @@ class EmptyResponseException(BadRequestException):
"""Exception class for empty response errors."""

def __init__(self, *, status_code: int = 204, message: str = "No Content") -> None:
"""
Initialize an empty response exception.

Args:
status_code (int): Status code for the error.
message (str): Error message.

"""
super().__init__(status_code=status_code, message=message)


class InvalidJsonException(PyritException):
"""Exception class for blocked content errors."""

def __init__(self, *, message: str = "Invalid JSON Response") -> None:
"""
Initialize an invalid JSON exception.

Args:
message (str): Error message.

"""
super().__init__(message=message)


class MissingPromptPlaceholderException(PyritException):
"""Exception class for missing prompt placeholder errors."""

def __init__(self, *, message: str = "No prompt placeholder") -> None:
"""
Initialize a missing placeholder exception.

Args:
message (str): Error message.

"""
super().__init__(message=message)


def pyrit_custom_result_retry(
retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None
) -> Callable[..., Any]:
"""
A decorator to apply retry logic with exponential backoff to a function.
Apply retry logic with exponential backoff to a function.

Retries the function if the result of the retry_function is True,
with a wait time between retries that follows an exponential backoff strategy.
Expand All @@ -162,10 +247,10 @@ def pyrit_custom_result_retry(
on the result of the decorated function.
retry_max_num_attempts (Optional, int): The maximum number of retry attempts. Defaults to
environment variable CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS or 10.
func (Callable): The function to be decorated.

Returns:
Callable: The decorated function with retry logic applied.

"""

def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]:
Expand All @@ -189,7 +274,7 @@ def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]:

def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic with exponential backoff to a function.
Apply retry logic with exponential backoff to a function.

Retries the function if it raises RateLimitError or EmptyResponseException,
with a wait time between retries that follows an exponential backoff strategy.
Expand All @@ -200,6 +285,7 @@ def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]:

Returns:
Callable: The decorated function with retry logic applied.

"""
return retry(
reraise=True,
Expand All @@ -214,7 +300,7 @@ def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]:

def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic to a function.
Apply retry logic to a function.

Retries the function if it raises a JSON error.
Logs retry attempts at the INFO level and stops after a maximum number of attempts.
Expand All @@ -224,6 +310,7 @@ def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]:

Returns:
Callable: The decorated function with retry logic applied.

"""
return retry(
reraise=True,
Expand All @@ -235,7 +322,7 @@ def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]:

def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic.
Apply retry logic.

Retries the function if it raises MissingPromptPlaceholderException.
Logs retry attempts at the INFO level and stops after a maximum number of attempts.
Expand All @@ -245,6 +332,7 @@ def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]:

Returns:
Callable: The decorated function with retry logic applied.

"""
return retry(
reraise=True,
Expand All @@ -260,6 +348,22 @@ def handle_bad_request_exception(
is_content_filter: bool = False,
error_code: int = 400,
) -> Message:
"""
Handle bad request responses and map them to standardized error messages.

Args:
response_text (str): Raw response text from the target.
request (MessagePiece): Original request piece that caused the error.
is_content_filter (bool): Whether the response is known to be content-filtered.
error_code (int): Status code to include in the generated error payload.

Returns:
Message: A constructed error response message.

Raises:
RuntimeError: If the response does not match bad-request content-filter conditions.

"""
if (
"content_filter" in response_text
or "Invalid prompt: your prompt was flagged as potentially violating our usage policy." in response_text
Expand Down
19 changes: 18 additions & 1 deletion pyrit/exceptions/exception_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_retry_context_string(self) -> str:

Returns:
str: A formatted string with component role, component name, and endpoint.

"""
parts = [self.component_role.value]
if self.component_name:
Expand All @@ -99,6 +100,7 @@ def get_exception_details(self) -> str:

Returns:
str: A multi-line formatted string with full context details.

"""
lines = []

Expand Down Expand Up @@ -136,6 +138,7 @@ def get_execution_context() -> Optional[ExecutionContext]:

Returns:
Optional[ExecutionContext]: The current context, or None if not set.

"""
return _execution_context.get()

Expand All @@ -146,6 +149,7 @@ def set_execution_context(context: ExecutionContext) -> None:

Args:
context: The execution context to set.

"""
_execution_context.set(context)

Expand All @@ -171,7 +175,13 @@ class ExecutionContextManager:
_token: Any = field(default=None, init=False, repr=False)

def __enter__(self) -> "ExecutionContextManager":
"""Set the execution context on entry."""
"""
Set the execution context on entry.

Returns:
ExecutionContextManager: The active context manager instance.

"""
self._token = _execution_context.set(self.context)
return self

Expand All @@ -181,6 +191,12 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

If an exception occurred, the context is preserved so that exception
handlers higher in the call stack can access it for enhanced error messages.

Args:
exc_type (Any): Exception type if one was raised, otherwise None.
exc_val (Any): Exception value if one was raised, otherwise None.
exc_tb (Any): Traceback object if an exception was raised, otherwise None.

"""
if exc_type is None:
# No exception - restore previous context
Expand Down Expand Up @@ -210,6 +226,7 @@ def execution_context(

Returns:
ExecutionContextManager: A context manager that sets/clears the context.

"""
# Extract endpoint and component_name from component_identifier if available
endpoint = None
Expand Down
Loading