From 87c42cec1951958a9f5837fec919cc6c5ed61296 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 23 Feb 2026 05:36:14 +0000 Subject: [PATCH 1/3] Complete ruff docstring cleanup for exceptions and models --- pyproject.toml | 2 +- pyrit/exceptions/__init__.py | 2 + pyrit/exceptions/exception_classes.py | 124 +++++++++++- pyrit/exceptions/exception_context.py | 19 +- pyrit/exceptions/exceptions_helpers.py | 13 +- pyrit/models/__init__.py | 2 + pyrit/models/attack_result.py | 8 + pyrit/models/chat_message.py | 16 +- pyrit/models/conversation_reference.py | 17 ++ pyrit/models/data_type_serializer.py | 185 ++++++++++++++++-- pyrit/models/embeddings.py | 29 ++- pyrit/models/harm_definition.py | 6 + pyrit/models/message.py | 167 ++++++++++++++-- pyrit/models/message_piece.py | 47 ++++- pyrit/models/question_answering.py | 30 ++- pyrit/models/scenario_result.py | 29 ++- pyrit/models/score.py | 63 +++++- pyrit/models/seeds/seed.py | 39 +++- pyrit/models/seeds/seed_attack_group.py | 11 +- pyrit/models/seeds/seed_dataset.py | 55 +++++- pyrit/models/seeds/seed_group.py | 77 +++++++- pyrit/models/seeds/seed_objective.py | 1 + pyrit/models/seeds/seed_prompt.py | 12 +- .../seeds/seed_simulated_conversation.py | 29 ++- pyrit/models/storage_io.py | 82 +++++++- pyrit/models/strategy_result.py | 1 + 26 files changed, 981 insertions(+), 85 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7fad04abe..e926e0610b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -282,7 +282,7 @@ notice-rgx = "Copyright \\(c\\) Microsoft Corporation\\.\\s*\\n.*Licensed under # Temporary ignores for pyrit/ subdirectories until issue #1176 # https://github.com/Azure/PyRIT/issues/1176 is fully resolved # TODO: Remove these ignores once the issues are fixed -"pyrit/{auxiliary_attacks,exceptions,models,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"] +"pyrit/{auxiliary_attacks,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"] # Backend API routes raise HTTPException handled by FastAPI, not true exceptions "pyrit/backend/**/*.py" = ["DOC501"] "pyrit/__init__.py" = ["D104"] diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py index 47f1d86a7b..9da9650c2d 100644 --- a/pyrit/exceptions/__init__.py +++ b/pyrit/exceptions/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +"""Exception classes, retry helpers, and execution context utilities.""" + from pyrit.exceptions.exception_classes import ( BadRequestException, EmptyResponseException, diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 336b221a41..a452e63f58 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -26,22 +26,46 @@ def _get_custom_result_retry_max_num_attempts() -> int: - """Get the maximum number of retry attempts for custom result retry decorator.""" + """ + Get the maximum number of retry attempts for custom result retry decorator. + + Returns: + int: Maximum retry attempts. + + """ return int(os.getenv("CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS", 10)) def get_retry_max_num_attempts() -> int: - """Get the maximum number of retry attempts.""" + """ + Get the maximum number of retry attempts. + + Returns: + int: Maximum retry attempts. + + """ return int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 10)) def _get_retry_wait_min_seconds() -> int: - """Get the minimum wait time in seconds between retries.""" + """ + Get the minimum wait time in seconds between retries. + + Returns: + int: Minimum wait duration in seconds. + + """ return int(os.getenv("RETRY_WAIT_MIN_SECONDS", 5)) def _get_retry_wait_max_seconds() -> int: - """Get the maximum wait time in seconds between retries.""" + """ + Get the maximum wait time in seconds between retries. + + Returns: + int: Maximum wait duration in seconds. + + """ return int(os.getenv("RETRY_WAIT_MAX_SECONDS", 220)) @@ -89,14 +113,28 @@ def __call__(self, retry_state: RetryCallState) -> float: class PyritException(Exception, ABC): + """Base exception class for PyRIT components.""" + def __init__(self, *, status_code: int = 500, message: str = "An error occurred") -> None: + """ + Initialize a PyritException. + + Args: + status_code (int): HTTP-style status code associated with the error. + message (str): Human-readable error description. + + """ self.status_code = status_code self.message = message super().__init__(f"Status Code: {status_code}, Message: {message}") def process_exception(self) -> str: """ - Logs and returns a string representation of the exception. + Log and return a JSON string representation of the exception. + + Returns: + str: Serialized status code and message. + """ log_message = f"{self.__class__.__name__} encountered: Status Code: {self.status_code}, Message: {self.message}" logger.error(log_message) @@ -108,6 +146,14 @@ class BadRequestException(PyritException): """Exception class for bad client requests.""" def __init__(self, *, status_code: int = 400, message: str = "Bad Request") -> None: + """ + Initialize a bad request exception. + + Args: + status_code (int): Status code for the error. + message (str): Error message. + + """ super().__init__(status_code=status_code, message=message) @@ -115,6 +161,14 @@ class RateLimitException(PyritException): """Exception class for authentication errors.""" def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Exception") -> None: + """ + Initialize a rate limit exception. + + Args: + status_code (int): Status code for the error. + message (str): Error message. + + """ super().__init__(status_code=status_code, message=message) @@ -122,6 +176,15 @@ class ServerErrorException(PyritException): """Exception class for opaque 5xx errors returned by the server.""" def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None: + """ + Initialize a server error exception. + + Args: + status_code (int): Status code for the error. + message (str): Error message. + body (Optional[str]): Optional raw server response body. + + """ super().__init__(status_code=status_code, message=message) self.body = body @@ -130,6 +193,14 @@ class EmptyResponseException(BadRequestException): """Exception class for empty response errors.""" def __init__(self, *, status_code: int = 204, message: str = "No Content") -> None: + """ + Initialize an empty response exception. + + Args: + status_code (int): Status code for the error. + message (str): Error message. + + """ super().__init__(status_code=status_code, message=message) @@ -137,6 +208,13 @@ class InvalidJsonException(PyritException): """Exception class for blocked content errors.""" def __init__(self, *, message: str = "Invalid JSON Response") -> None: + """ + Initialize an invalid JSON exception. + + Args: + message (str): Error message. + + """ super().__init__(message=message) @@ -144,6 +222,13 @@ class MissingPromptPlaceholderException(PyritException): """Exception class for missing prompt placeholder errors.""" def __init__(self, *, message: str = "No prompt placeholder") -> None: + """ + Initialize a missing placeholder exception. + + Args: + message (str): Error message. + + """ super().__init__(message=message) @@ -151,7 +236,7 @@ def pyrit_custom_result_retry( retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None ) -> Callable[..., Any]: """ - A decorator to apply retry logic with exponential backoff to a function. + Apply retry logic with exponential backoff to a function. Retries the function if the result of the retry_function is True, with a wait time between retries that follows an exponential backoff strategy. @@ -162,10 +247,10 @@ def pyrit_custom_result_retry( on the result of the decorated function. retry_max_num_attempts (Optional, int): The maximum number of retry attempts. Defaults to environment variable CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS or 10. - func (Callable): The function to be decorated. Returns: Callable: The decorated function with retry logic applied. + """ def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]: @@ -189,7 +274,7 @@ def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]: def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ - A decorator to apply retry logic with exponential backoff to a function. + Apply retry logic with exponential backoff to a function. Retries the function if it raises RateLimitError or EmptyResponseException, with a wait time between retries that follows an exponential backoff strategy. @@ -200,6 +285,7 @@ def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]: Returns: Callable: The decorated function with retry logic applied. + """ return retry( reraise=True, @@ -214,7 +300,7 @@ def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]: def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ - A decorator to apply retry logic to a function. + Apply retry logic to a function. Retries the function if it raises a JSON error. Logs retry attempts at the INFO level and stops after a maximum number of attempts. @@ -224,6 +310,7 @@ def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]: Returns: Callable: The decorated function with retry logic applied. + """ return retry( reraise=True, @@ -235,7 +322,7 @@ def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]: def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ - A decorator to apply retry logic. + Apply retry logic. Retries the function if it raises MissingPromptPlaceholderException. Logs retry attempts at the INFO level and stops after a maximum number of attempts. @@ -245,6 +332,7 @@ def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]: Returns: Callable: The decorated function with retry logic applied. + """ return retry( reraise=True, @@ -260,6 +348,22 @@ def handle_bad_request_exception( is_content_filter: bool = False, error_code: int = 400, ) -> Message: + """ + Handle bad request responses and map them to standardized error messages. + + Args: + response_text (str): Raw response text from the target. + request (MessagePiece): Original request piece that caused the error. + is_content_filter (bool): Whether the response is known to be content-filtered. + error_code (int): Status code to include in the generated error payload. + + Returns: + Message: A constructed error response message. + + Raises: + RuntimeError: If the response does not match bad-request content-filter conditions. + + """ if ( "content_filter" in response_text or "Invalid prompt: your prompt was flagged as potentially violating our usage policy." in response_text diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index b88c92a017..7dff15b312 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -85,6 +85,7 @@ def get_retry_context_string(self) -> str: Returns: str: A formatted string with component role, component name, and endpoint. + """ parts = [self.component_role.value] if self.component_name: @@ -99,6 +100,7 @@ def get_exception_details(self) -> str: Returns: str: A multi-line formatted string with full context details. + """ lines = [] @@ -136,6 +138,7 @@ def get_execution_context() -> Optional[ExecutionContext]: Returns: Optional[ExecutionContext]: The current context, or None if not set. + """ return _execution_context.get() @@ -146,6 +149,7 @@ def set_execution_context(context: ExecutionContext) -> None: Args: context: The execution context to set. + """ _execution_context.set(context) @@ -171,7 +175,13 @@ class ExecutionContextManager: _token: Any = field(default=None, init=False, repr=False) def __enter__(self) -> "ExecutionContextManager": - """Set the execution context on entry.""" + """ + Set the execution context on entry. + + Returns: + ExecutionContextManager: The active context manager instance. + + """ self._token = _execution_context.set(self.context) return self @@ -181,6 +191,12 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: If an exception occurred, the context is preserved so that exception handlers higher in the call stack can access it for enhanced error messages. + + Args: + exc_type (Any): Exception type if one was raised, otherwise None. + exc_val (Any): Exception value if one was raised, otherwise None. + exc_tb (Any): Traceback object if an exception was raised, otherwise None. + """ if exc_type is None: # No exception - restore previous context @@ -210,6 +226,7 @@ def execution_context( Returns: ExecutionContextManager: A context manager that sets/clears the context. + """ # Extract endpoint and component_name from component_identifier if available endpoint = None diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index c0f2e69c7e..8062c47eaf 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -22,6 +22,7 @@ def log_exception(retry_state: RetryCallState) -> None: Args: retry_state: The tenacity retry state containing attempt information. + """ # Validate retry_state has required attributes before proceeding if not retry_state: @@ -67,13 +68,14 @@ def log_exception(retry_state: RetryCallState) -> None: def remove_start_md_json(response_msg: str) -> str: """ - Checks the message for the listed start patterns and removes them if present. + Check the message for supported start markers and remove them if present. Args: response_msg (str): The response message to check. Returns: str: The response message without the start marker (if one was present). + """ start_pattern = re.compile(r"^(```json\n|`json\n|```\n|`\n|```json|`json|```|`|json|json\n)") match = start_pattern.match(response_msg) @@ -85,13 +87,14 @@ def remove_start_md_json(response_msg: str) -> str: def remove_end_md_json(response_msg: str) -> str: """ - Checks the message for the listed end patterns and removes them if present. + Check the message for supported end markers and remove them if present. Args: response_msg (str): The response message to check. Returns: str: The response message without the end marker (if one was present). + """ end_pattern = re.compile(r"(\n```|\n`|```|`)$") match = end_pattern.search(response_msg) @@ -103,13 +106,14 @@ def remove_end_md_json(response_msg: str) -> str: def extract_json_from_string(response_msg: str) -> str: """ - Attempts to extract JSON (object or array) from within a larger string, not specific to markdown. + Extract JSON (object or array) from within a larger string. Args: response_msg (str): The response message to check. Returns: str: The extracted JSON string if found, otherwise the original string. + """ json_pattern = re.compile(r"\{.*\}|\[.*\]") match = json_pattern.search(response_msg) @@ -121,13 +125,14 @@ def extract_json_from_string(response_msg: str) -> str: def remove_markdown_json(response_msg: str) -> str: """ - Checks if the response message is in JSON format and removes Markdown formatting if present. + Remove markdown wrappers and return a JSON payload when possible. Args: response_msg (str): The response message to check. Returns: str: The response message without Markdown formatting if present. + """ response_msg = remove_start_md_json(response_msg) response_msg = remove_end_md_json(response_msg) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index c5851f2cc5..9f7994f968 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +"""Public model exports for PyRIT core data structures and helpers.""" + from pyrit.models.attack_result import AttackOutcome, AttackResult, AttackResultT from pyrit.models.chat_message import ( ALLOWED_CHAT_MESSAGE_ROLES, diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 3518ce4626..2b14029ce8 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -84,8 +84,16 @@ def get_conversations_by_type(self, conversation_type: ConversationType) -> list Returns: list: A list of related conversations matching the specified type. + """ return [ref for ref in self.related_conversations if ref.conversation_type == conversation_type] def __str__(self) -> str: + """ + Return a concise string representation of this attack result. + + Returns: + str: Summary containing conversation ID, outcome, and objective preview. + + """ return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index 7b8a5c0002..bd05fd83c3 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -12,6 +12,8 @@ class ToolCall(BaseModel): + """Represents a tool invocation requested by the assistant.""" + model_config = ConfigDict(extra="forbid") id: str type: str @@ -40,6 +42,7 @@ def to_json(self) -> str: Returns: A JSON string representation of the message. + """ return self.model_dump_json() @@ -49,6 +52,7 @@ def to_dict(self) -> dict[str, Any]: Returns: A dictionary representation of the message, excluding None values. + """ return self.model_dump(exclude_none=True) @@ -62,6 +66,7 @@ def from_json(cls, json_str: str) -> "ChatMessage": Returns: A ChatMessage instance. + """ return cls.model_validate_json(json_str) @@ -74,6 +79,13 @@ class ChatMessageListDictContent(ChatMessage): """ def __init__(self, **data: Any) -> None: + """ + Initialize a deprecated compatibility wrapper around ChatMessage. + + Args: + **data (Any): Keyword arguments accepted by ChatMessage. + + """ print_deprecation_message( old_item="ChatMessageListDictContent", new_item="ChatMessage", @@ -86,11 +98,13 @@ class ChatMessagesDataset(BaseModel): """ Represents a dataset of chat messages. - Parameters: + Parameters + ---------- model_config (ConfigDict): The model configuration. name (str): The name of the dataset. description (str): The description of the dataset. list_of_chat_messages (list[list[ChatMessage]]): A list of chat messages. + """ model_config = ConfigDict(extra="forbid") diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index eccdb70cde..0932cca051 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -27,7 +27,24 @@ class ConversationReference: # Allow use in set / dict def __hash__(self) -> int: + """ + Return a hash derived from conversation ID. + + Returns: + int: Hash value for this reference. + + """ return hash(self.conversation_id) def __eq__(self, other: object) -> bool: + """ + Compare two references by conversation ID. + + Args: + other (object): Other object to compare. + + Returns: + bool: True when the other object is a matching ConversationReference. + + """ return isinstance(other, ConversationReference) and self.conversation_id == other.conversation_id diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 4ace4e2bba..da51674d0d 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -35,7 +35,7 @@ def data_serializer_factory( category: AllowedCategories, ) -> "DataTypeSerializer": """ - Factory method to create a DataTypeSerializer instance. + Create a DataTypeSerializer instance. Args: data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). @@ -48,6 +48,7 @@ def data_serializer_factory( Raises: ValueError: If the category is not provided or invalid. + """ if not category: raise ValueError( @@ -115,6 +116,7 @@ def _get_storage_io(self) -> StorageIO: Raises: ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. + """ if self._is_azure_storage_url(self.value): # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact @@ -125,16 +127,21 @@ def _get_storage_io(self) -> StorageIO: @abc.abstractmethod def data_on_disk(self) -> bool: """ - Returns True if the data is stored on disk. + Indicate whether the data is stored on disk. + + Returns: + bool: True when data is persisted on disk. + """ async def save_data(self, data: bytes, output_filename: Optional[str] = None) -> None: """ - Saves the data to storage. + Save data to storage. Arguments: data: bytes: The data to be saved. output_filename (optional, str): filename to store data as. Defaults to UUID if not provided + """ file_path = await self.get_data_filename(file_name=output_filename) await self._memory.results_storage_io.write_file(file_path, data) @@ -142,11 +149,12 @@ async def save_data(self, data: bytes, output_filename: Optional[str] = None) -> async def save_b64_image(self, data: str | bytes, output_filename: str = None) -> None: """ - Saves the base64 encoded image to storage. + Save a base64-encoded image to storage. Arguments: data: string or bytes with base64 data output_filename (optional, str): filename to store image as. Defaults to UUID if not provided + """ file_path = await self.get_data_filename(file_name=output_filename) image_bytes = base64.b64decode(data) @@ -162,7 +170,7 @@ async def save_formatted_audio( output_filename: Optional[str] = None, ) -> None: """ - Saves the PCM16 of other specially formatted audio data to storage. + Save PCM16 or similarly formatted audio data to storage. Arguments: data: bytes with audio data @@ -170,6 +178,7 @@ async def save_formatted_audio( num_channels (optional, int): number of channels in audio data. Defaults to 1 sample_width (optional, int): sample width in bytes. Defaults to 2 sample_rate (optional, int): sample rate in Hz. Defaults to 16000 + """ file_path = await self.get_data_filename(file_name=output_filename) @@ -199,10 +208,16 @@ async def save_formatted_audio( async def read_data(self) -> bytes: """ - Reads the data from the storage. + Read data from storage. Returns: bytes: The data read from storage. + + Raises: + TypeError: If the serializer does not represent on-disk data. + RuntimeError: If no value is set. + FileNotFoundError: If the referenced file does not exist. + """ if not self.data_on_disk(): raise TypeError(f"Data for data Type {self.data_type} is not stored on disk") @@ -220,12 +235,27 @@ async def read_data(self) -> bytes: async def read_data_base64(self) -> str: """ - Reads the data from the storage. + Read data from storage and return it as a base64 string. + + Returns: + str: Base64-encoded data. + """ byte_array = await self.read_data() return base64.b64encode(byte_array).decode("utf-8") async def get_sha256(self) -> str: + """ + Compute SHA256 hash for this serializer's current value. + + Returns: + str: Hex digest of the computed SHA256 hash. + + Raises: + FileNotFoundError: If on-disk data path does not exist. + ValueError: If in-memory data cannot be converted to bytes. + + """ input_bytes: bytes = None if self.data_on_disk(): @@ -247,7 +277,18 @@ async def get_sha256(self) -> str: async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path, str]: """ - Generates or retrieves a unique filename for the data file. + Generate or retrieve a unique filename for the data file. + + Args: + file_name (Optional[str]): Optional file name override. + + Returns: + Union[Path, str]: Full storage path for the generated data file. + + Raises: + TypeError: If the serializer is not configured for on-disk data. + RuntimeError: If required data subdirectory information is missing. + """ if self._file_path: return self._file_path @@ -276,6 +317,13 @@ async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path def get_extension(file_path: str) -> str | None: """ Get the file extension from the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: File extension (including dot) or None if unavailable. + """ _, ext = os.path.splitext(file_path) return ext if ext else None @@ -284,44 +332,96 @@ def get_extension(file_path: str) -> str | None: def get_mime_type(file_path: str) -> str | None: """ Get the MIME type of the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: MIME type if detectable; otherwise None. + """ mime_type, _ = guess_type(file_path) return mime_type def _is_azure_storage_url(self, path: str) -> bool: """ - Validates if the given path is an Azure Storage URL. + Validate whether the given path is an Azure Storage URL. Args: path (str): Path or URL to check. Returns: bool: True if the path is an Azure Blob Storage URL. + """ parsed = urlparse(path) return parsed.scheme in ("http", "https") and "blob.core.windows.net" in parsed.netloc class TextDataTypeSerializer(DataTypeSerializer): + """Serializer for text and text-like prompt values that stay in-memory.""" + def __init__(self, *, prompt_text: str, data_type: PromptDataType = "text"): + """ + Initialize a text serializer. + + Args: + prompt_text (str): Prompt value. + data_type (PromptDataType): Text-like prompt data type. + + """ self.data_type = data_type self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for text serializers. + + """ return False class ErrorDataTypeSerializer(DataTypeSerializer): + """Serializer for error payloads stored as in-memory text.""" + def __init__(self, *, prompt_text: str): + """ + Initialize an error serializer. + + Args: + prompt_text (str): Error payload text. + + """ self.data_type = "error" self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for error serializers. + + """ return False class URLDataTypeSerializer(DataTypeSerializer): + """Serializer for URL values and URL-backed local file references.""" + def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None): + """ + Initialize a URL serializer. + + Args: + category (str): Data category folder name. + prompt_text (str): URL or path value. + extension (Optional[str]): Optional extension for persisted content. + + """ self.data_type = "url" self.value = prompt_text self.data_sub_directory = f"/{category}/urls" @@ -329,11 +429,29 @@ def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] self.on_disk = not (prompt_text.startswith("http://") or prompt_text.startswith("https://")) def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: True for non-http values, False for URL values. + + """ return self.on_disk class ImagePathDataTypeSerializer(DataTypeSerializer): + """Serializer for image path values stored on disk.""" + def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None): + """ + Initialize an image-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing image path. + extension (Optional[str]): Optional image extension. + + """ self.data_type = "image_path" self.data_sub_directory = f"/{category}/images" self.file_extension = extension if extension else "png" @@ -342,10 +460,19 @@ def __init__(self, *, category: str, prompt_text: Optional[str] = None, extensio self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for image path serializers. + + """ return True class AudioPathDataTypeSerializer(DataTypeSerializer): + """Serializer for audio path values stored on disk.""" + def __init__( self, *, @@ -353,6 +480,15 @@ def __init__( prompt_text: Optional[str] = None, extension: Optional[str] = None, ): + """ + Initialize an audio-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing audio path. + extension (Optional[str]): Optional audio extension. + + """ self.data_type = "audio_path" self.data_sub_directory = f"/{category}/audio" self.file_extension = extension if extension else "mp3" @@ -361,10 +497,19 @@ def __init__( self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for audio path serializers. + + """ return True class VideoPathDataTypeSerializer(DataTypeSerializer): + """Serializer for video path values stored on disk.""" + def __init__( self, *, @@ -373,12 +518,13 @@ def __init__( extension: Optional[str] = None, ): """ - Serializer for video data paths. + Initialize a video-path serializer. Args: category (str): The category or context for the data. prompt_text (Optional[str]): The video path or identifier. extension (Optional[str]): The file extension, defaults to 'mp4'. + """ self.data_type = "video_path" self.data_sub_directory = f"/{category}/videos" @@ -388,10 +534,19 @@ def __init__( self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for video path serializers. + + """ return True class BinaryPathDataTypeSerializer(DataTypeSerializer): + """Serializer for generic binary path values stored on disk.""" + def __init__( self, *, @@ -400,7 +555,7 @@ def __init__( extension: Optional[str] = None, ): """ - Serializer for arbitrary binary data paths. + Initialize a generic binary-path serializer. This serializer handles generic binary data that doesn't fit into specific categories like images, audio, or video. Useful for XPIA attacks and @@ -410,6 +565,7 @@ def __init__( category (str): The category or context for the data. prompt_text (Optional[str]): The binary file path or identifier. extension (Optional[str]): The file extension, defaults to 'bin'. + """ self.data_type = "binary_path" self.data_sub_directory = f"/{category}/binaries" @@ -419,4 +575,11 @@ def __init__( self.value = prompt_text def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for binary path serializers. + + """ return True diff --git a/pyrit/models/embeddings.py b/pyrit/models/embeddings.py index a19bf7edf2..0fb7c36529 100644 --- a/pyrit/models/embeddings.py +++ b/pyrit/models/embeddings.py @@ -11,12 +11,16 @@ class EmbeddingUsageInformation(BaseModel): + """Token usage metadata returned by an embedding API.""" + model_config = ConfigDict(extra="forbid") prompt_tokens: int total_tokens: int class EmbeddingData(BaseModel): + """Single embedding vector payload with index and object metadata.""" + model_config = ConfigDict(extra="forbid") embedding: list[float] index: int @@ -24,6 +28,8 @@ class EmbeddingData(BaseModel): class EmbeddingResponse(BaseModel): + """Embedding API response containing vectors, model metadata, and usage.""" + model_config = ConfigDict(extra="forbid") model: str object: str @@ -35,9 +41,11 @@ def save_to_file(self, directory_path: Path) -> str: Save the embedding response to disk and return the path of the new file. Args: - directory_path: The path to save the file to + directory_path: The path to save the file to. + Returns: - The full path to the file that was saved + str: The full path to the file that was saved. + """ embedding_json = self.model_dump_json() embedding_hash = sha256(embedding_json.encode()).hexdigest() @@ -51,18 +59,29 @@ def load_from_file(file_path: Path) -> EmbeddingResponse: Load the embedding response from disk. Args: - file_path: The path to load the file from + file_path: The path to load the file from. + Returns: - The loaded embedding response + EmbeddingResponse: The loaded embedding response. + """ embedding_json_data = file_path.read_text(encoding="utf-8") return EmbeddingResponse.model_validate_json(embedding_json_data) def to_json(self) -> str: + """ + Serialize this embedding response to JSON. + + Returns: + str: JSON-encoded embedding response. + + """ return self.model_dump_json() class EmbeddingSupport(ABC): + """Protocol-like interface for classes that generate text embeddings.""" + @abstractmethod def generate_text_embedding(self, text: str, **kwargs: object) -> EmbeddingResponse: """ @@ -74,6 +93,7 @@ def generate_text_embedding(self, text: str, **kwargs: object) -> EmbeddingRespo Returns: The embedding response + """ raise NotImplementedError("generate_text_embedding method not implemented") @@ -88,5 +108,6 @@ async def generate_text_embedding_async(self, text: str, **kwargs: object) -> Em Returns: The embedding response + """ raise NotImplementedError("generate_text_embedding_async method not implemented") diff --git a/pyrit/models/harm_definition.py b/pyrit/models/harm_definition.py index 8974cb853e..264e202350 100644 --- a/pyrit/models/harm_definition.py +++ b/pyrit/models/harm_definition.py @@ -28,6 +28,7 @@ class ScaleDescription: Args: score_value: The score value (e.g., "1", "2", etc.) description: The description for this score level. + """ score_value: str @@ -48,6 +49,7 @@ class HarmDefinition: category: The harm category name (e.g., "violence", "hate_speech"). scale_descriptions: List of scale descriptions defining score levels. source_path: The path to the YAML file this was loaded from. + """ version: str @@ -64,6 +66,7 @@ def get_scale_description(self, score_value: str) -> Optional[str]: Returns: The description for the score value, or None if not found. + """ for scale in self.scale_descriptions: if scale.score_value == score_value: @@ -87,6 +90,7 @@ def validate_category(category: str, *, check_exists: bool = False) -> bool: Returns: True if the category is valid (and exists if check_exists is True), False otherwise. + """ # Check if category matches pattern: only lowercase letters and underscores if not re.match(r"^[a-z_]+$", category): @@ -119,6 +123,7 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": Raises: FileNotFoundError: If the harm definition file does not exist. ValueError: If the YAML file is invalid or missing required fields. + """ path = Path(harm_definition_path) @@ -187,6 +192,7 @@ def get_all_harm_definitions() -> Dict[str, HarmDefinition]: Raises: ValueError: If any YAML file in the directory is invalid. + """ harm_definitions: Dict[str, HarmDefinition] = {} diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 509d70cb29..0ba939e5cd 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -20,11 +20,24 @@ class Message: This is a single request to a target. It can contain multiple message pieces. - Parameters: + Parameters + ---------- message_pieces (Sequence[MessagePiece]): The list of message pieces. + """ def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False) -> None: + """ + Initialize a Message from one or more message pieces. + + Args: + message_pieces (Sequence[MessagePiece]): Pieces belonging to the same message turn. + skip_validation (Optional[bool]): Whether to skip consistency validation. + + Raises: + ValueError: If no message pieces are provided. + + """ if not message_pieces: raise ValueError("Message must have at least one message piece.") self.message_pieces = message_pieces @@ -32,17 +45,48 @@ def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: O self.validate() def get_value(self, n: int = 0) -> str: - """Return the converted value of the nth message piece.""" + """ + Return the converted value of the nth message piece. + + Args: + n (int): Zero-based index of the piece to read. + + Returns: + str: Converted value of the selected message piece. + + Raises: + IndexError: If the index is out of bounds. + + """ if n >= len(self.message_pieces): raise IndexError(f"No message piece at index {n}.") return self.message_pieces[n].converted_value def get_values(self) -> list[str]: - """Return the converted values of all message pieces.""" + """ + Return the converted values of all message pieces. + + Returns: + list[str]: Converted values for all message pieces. + + """ return [message_piece.converted_value for message_piece in self.message_pieces] def get_piece(self, n: int = 0) -> MessagePiece: - """Return the nth message piece.""" + """ + Return the nth message piece. + + Args: + n (int): Zero-based index of the piece to return. + + Returns: + MessagePiece: Selected message piece. + + Raises: + ValueError: If the message has no pieces. + IndexError: If the index is out of bounds. + + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") @@ -68,6 +112,7 @@ def get_pieces_by_type( Returns: A list of matching MessagePiece objects (may be empty). + """ effective_converted = converted_value_data_type or data_type results = self.message_pieces @@ -94,6 +139,7 @@ def get_piece_by_type( Returns: The first matching MessagePiece, or None if no match is found. + """ pieces = self.get_pieces_by_type( data_type=data_type, @@ -109,6 +155,13 @@ def api_role(self) -> ChatMessageRole: Maps simulated_assistant to assistant for API compatibility. All message pieces in a Message should have the same role. + + Returns: + ChatMessageRole: Role compatible with external API calls. + + Raises: + ValueError: If the message has no pieces. + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") @@ -143,21 +196,43 @@ def role(self) -> ChatMessageRole: @property def conversation_id(self) -> str: - """Return the conversation ID of the first request piece (they should all be the same).""" + """ + Return the conversation ID of the first request piece. + + Returns: + str: Conversation identifier. + + Raises: + ValueError: If the message has no pieces. + + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") return self.message_pieces[0].conversation_id @property def sequence(self) -> int: - """Return the sequence of the first request piece (they should all be the same).""" + """ + Return the sequence value of the first request piece. + + Returns: + int: Sequence number for the message turn. + + Raises: + ValueError: If the message has no pieces. + + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") return self.message_pieces[0].sequence def is_error(self) -> bool: """ - Returns True if any of the message pieces have an error response. + Check whether any message piece indicates an error. + + Returns: + bool: True when any piece has a non-none error flag or error data type. + """ for piece in self.message_pieces: if piece.response_error != "none" or piece.converted_value_data_type == "error": @@ -186,7 +261,12 @@ def set_simulated_role(self) -> None: def validate(self) -> None: """ - Validates the request response. + Validate that all message pieces are internally consistent. + + Raises: + ValueError: If piece collection is empty or contains mismatched conversation IDs, + sequence numbers, roles, or missing converted values. + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") @@ -208,6 +288,13 @@ def validate(self) -> None: raise ValueError("Inconsistent roles within the same message entry.") def __str__(self) -> str: + """ + Return a newline-delimited string representation of message pieces. + + Returns: + str: Concatenated string representation. + + """ ret = "" for message_piece in self.message_pieces: ret += str(message_piece) + "\n" @@ -220,6 +307,7 @@ def to_dict(self) -> dict[str, object]: Returns: dict: A dictionary with 'role', 'converted_value', 'conversation_id', 'sequence', and 'converted_value_data_type' keys. + """ if len(self.message_pieces) == 1: converted_value: str | list[str] = self.message_pieces[0].converted_value @@ -238,7 +326,16 @@ def to_dict(self) -> dict[str, object]: @staticmethod def get_all_values(messages: Sequence[Message]) -> list[str]: - """Return all converted values across the provided messages.""" + """ + Return all converted values across the provided messages. + + Args: + messages (Sequence[Message]): Messages to aggregate. + + Returns: + list[str]: Flattened list of converted values. + + """ values: list[str] = [] for message in messages: values.extend(message.get_values()) @@ -248,6 +345,16 @@ def get_all_values(messages: Sequence[Message]) -> list[str]: def flatten_to_message_pieces( messages: Sequence[Message], ) -> MutableSequence[MessagePiece]: + """ + Flatten messages into a single list of message pieces. + + Args: + messages (Sequence[Message]): Messages to flatten. + + Returns: + MutableSequence[MessagePiece]: Flattened message pieces. + + """ if not messages: return [] message_pieces: MutableSequence[MessagePiece] = [] @@ -265,11 +372,33 @@ def from_prompt( role: ChatMessageRole, prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, ) -> Message: + """ + Build a single-piece message from prompt text. + + Args: + prompt (str): Prompt text. + role (ChatMessageRole): Role assigned to the message piece. + prompt_metadata (Optional[Dict[str, Union[str, int]]]): Optional prompt metadata. + + Returns: + Message: Constructed message instance. + + """ piece = MessagePiece(original_value=prompt, role=role, prompt_metadata=prompt_metadata) return cls(message_pieces=[piece]) @classmethod def from_system_prompt(cls, system_prompt: str) -> Message: + """ + Build a message from a system prompt. + + Args: + system_prompt (str): System instruction text. + + Returns: + Message: Constructed system-role message. + + """ return cls.from_prompt(prompt=system_prompt, role="system") def duplicate_message(self) -> Message: @@ -284,6 +413,7 @@ def duplicate_message(self) -> Message: Returns: Message: A new Message with deep-copied message pieces, new IDs, and fresh timestamp. + """ new_pieces = copy.deepcopy(self.message_pieces) new_timestamp = datetime.now() @@ -298,7 +428,7 @@ def group_conversation_message_pieces_by_sequence( message_pieces: Sequence[MessagePiece], ) -> MutableSequence[Message]: """ - Groups message pieces from the same conversation into Messages. + Group message pieces from the same conversation into messages. This is done using the sequence number and conversation ID. @@ -334,6 +464,7 @@ def group_conversation_message_pieces_by_sequence( ... MessagePiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!") ... ]) ... ] + """ if not message_pieces: return [] @@ -362,7 +493,7 @@ def group_message_pieces_into_conversations( message_pieces: Sequence[MessagePiece], ) -> list[list[Message]]: """ - Groups message pieces from multiple conversations into separate conversation groups. + Group message pieces from multiple conversations into separate conversation groups. This function first groups pieces by conversation ID, then groups each conversation's pieces by sequence number. Each conversation is returned as a separate list of @@ -389,6 +520,7 @@ def group_message_pieces_into_conversations( >>> # [Message(seq=1), Message(seq=2)], # conv1 >>> # [Message(seq=1), Message(seq=2)] # conv2 >>> # ] + """ if not message_pieces: return [] @@ -418,7 +550,18 @@ def construct_response_from_request( error: PromptResponseError = "none", ) -> Message: """ - Constructs a response entry from a request. + Construct a response message from a request message piece. + + Args: + request (MessagePiece): Source request message piece. + response_text_pieces (list[str]): Response values to include. + response_type (PromptDataType): Data type for original and converted response values. + prompt_metadata (Optional[Dict[str, Union[str, int]]]): Additional metadata to merge. + error (PromptResponseError): Error classification for the response. + + Returns: + Message: Constructed response message. + """ if request.prompt_metadata: prompt_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {}) diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index f07b045318..7c60f0b688 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -82,6 +82,10 @@ def __init__( timestamp: The timestamp of the memory entry. Defaults to None (auto-generated). scores: The scores associated with the prompt. Defaults to None. targeted_harm_categories: The harm categories associated with the prompt. Defaults to None. + + Raises: + ValueError: If role, data types, or response error are invalid. + """ self.id = id if id else uuid4() @@ -160,7 +164,7 @@ def __init__( async def set_sha256_values_async(self) -> None: """ - This method computes the SHA256 hash values asynchronously. + Compute SHA256 hash values for original and converted payloads. It should be called after object creation if `original_value` and `converted_value` are set. Note, this method is async due to the blob retrieval. And because of that, we opted @@ -211,6 +215,7 @@ def get_role_for_storage(self) -> ChatMessageRole: Returns: The actual role stored (may be simulated_assistant). + """ return self._role @@ -242,6 +247,7 @@ def role(self, value: ChatMessageRole) -> None: Raises: ValueError: If the role is not a valid ChatMessageRole. + """ if value not in ChatMessageRole.__args__: # type: ignore raise ValueError(f"Role {value} is not a valid role.") @@ -255,12 +261,20 @@ def to_message(self) -> Message: # type: ignore # noqa F821 def has_error(self) -> bool: """ Check if the message piece has an error. + + Returns: + bool: True when the response_error is not "none". + """ return self.response_error != "none" def is_blocked(self) -> bool: """ Check if the message piece is blocked. + + Returns: + bool: True when the response_error is "blocked". + """ return self.response_error == "blocked" @@ -273,6 +287,13 @@ def set_piece_not_in_database(self) -> None: self.id = None def to_dict(self) -> dict[str, object]: + """ + Convert this message piece to a dictionary representation. + + Returns: + dict[str, object]: Dictionary representation suitable for serialization. + + """ return { "id": str(self.id), "role": self._role, @@ -301,12 +322,29 @@ def to_dict(self) -> dict[str, object]: } def __str__(self) -> str: + """ + Return a concise string representation of this message piece. + + Returns: + str: Target, role, and converted value summary. + + """ target_str = self.prompt_target_identifier.class_name if self.prompt_target_identifier else "Unknown" return f"{target_str}: {self._role}: {self.converted_value}" __repr__ = __str__ def __eq__(self, other: object) -> bool: + """ + Compare this message piece with another for semantic equality. + + Args: + other (object): Object to compare. + + Returns: + bool: True when all relevant message fields match. + + """ if not isinstance(other, MessagePiece): return NotImplemented return ( @@ -328,6 +366,13 @@ def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece Group by conversation_id. Order conversations by the earliest timestamp within each conversation_id. Within each conversation, order messages by sequence. + + Args: + message_pieces (list[MessagePiece]): Message pieces to sort. + + Returns: + list[MessagePiece]: Sorted message pieces. + """ earliest_timestamps = { convo_id: min(x.timestamp for x in message_pieces if x.conversation_id == convo_id) diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index 614bc2022a..610eb7ae37 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -12,9 +12,11 @@ class QuestionChoice(BaseModel): """ Represents a choice for a question. - Parameters: + Parameters + ---------- index (int): The index of the choice. text (str): The text of the choice. + """ model_config = ConfigDict(extra="forbid") @@ -26,7 +28,8 @@ class QuestionAnsweringEntry(BaseModel): """ Represents a question model. - Parameters: + Parameters + ---------- question (str): The question text. answer_type (Literal["int", "float", "str", "bool"]): The type of the answer. `int` for integer answers (e.g., when the answer is an index of the correct option in a multiple-choice @@ -36,6 +39,7 @@ class QuestionAnsweringEntry(BaseModel): `bool` for boolean answers. correct_answer (Union[int, str, float]): The correct answer. choices (list[QuestionChoice]): The list of choices for the question. + """ model_config = ConfigDict(extra="forbid") @@ -45,7 +49,16 @@ class QuestionAnsweringEntry(BaseModel): choices: list[QuestionChoice] def get_correct_answer_text(self) -> str: - """Get the text of the correct answer.""" + """ + Get the text of the correct answer. + + Returns: + str: Text corresponding to the configured correct answer index. + + Raises: + ValueError: If no choice matches the configured correct answer. + + """ correct_answer_index = self.correct_answer try: # Match using the explicit choice.index (not enumerate position) so non-sequential indices are supported @@ -57,6 +70,13 @@ def get_correct_answer_text(self) -> str: ) def __hash__(self) -> int: + """ + Return a stable hash for this question entry. + + Returns: + int: Hash computed from serialized model content. + + """ return hash(self.model_dump_json()) @@ -64,7 +84,8 @@ class QuestionAnsweringDataset(BaseModel): """ Represents a dataset for question answering. - Parameters: + Parameters + ---------- name (str): The name of the dataset. version (str): The version of the dataset. description (str): A description of the dataset. @@ -72,6 +93,7 @@ class QuestionAnsweringDataset(BaseModel): group (str): The group associated with the dataset. source (str): The source of the dataset. questions (list[QuestionAnsweringEntry]): A list of question models. + """ model_config = ConfigDict(extra="forbid") diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index ab0ba0e5e6..45570a284b 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -39,6 +39,7 @@ def __init__( scenario_version (int): Version of the scenario. init_data (Optional[dict]): Initialization data. pyrit_version (Optional[str]): PyRIT version string. If None, uses current version. + """ self.name = name self.description = description @@ -70,6 +71,22 @@ def __init__( # Deprecated parameter - will be removed in 0.13.0 objective_scorer: Optional["Scorer"] = None, ) -> None: + """ + Initialize a scenario result. + + Args: + scenario_identifier (ScenarioIdentifier): Identifier for the executed scenario. + objective_target_identifier (Union[Dict[str, Any], TargetIdentifier]): Target identifier. + attack_results (dict[str, List[AttackResult]]): Results grouped by atomic attack name. + objective_scorer_identifier (Union[Dict[str, Any], ScorerIdentifier]): Objective scorer identifier. + scenario_run_state (ScenarioRunState): Current scenario run state. + labels (Optional[dict[str, str]]): Optional labels. + completion_time (Optional[datetime]): Optional completion timestamp. + number_tries (int): Number of run attempts. + id (Optional[uuid.UUID]): Optional scenario result ID. + objective_scorer (Optional[Scorer]): Deprecated scorer object parameter. + + """ from pyrit.common import print_deprecation_message from pyrit.identifiers import ScorerIdentifier, TargetIdentifier @@ -99,7 +116,13 @@ def __init__( self.number_tries = number_tries def get_strategies_used(self) -> List[str]: - """Get the list of strategies used in this scenario.""" + """ + Get the list of strategies used in this scenario. + + Returns: + List[str]: Atomic attack strategy names present in the results. + + """ return list(self.attack_results.keys()) def get_objectives(self, *, atomic_attack_name: Optional[str] = None) -> List[str]: @@ -112,6 +135,7 @@ def get_objectives(self, *, atomic_attack_name: Optional[str] = None) -> List[st Returns: List[str]: Deduplicated list of objectives. + """ objectives: List[str] = [] strategies_to_process: List[List[AttackResult]] @@ -142,6 +166,7 @@ def objective_achieved_rate(self, *, atomic_attack_name: Optional[str] = None) - Returns: int: Success rate as a percentage (0-100). + """ if not atomic_attack_name: # Calculate rate across all atomic attacks @@ -179,6 +204,7 @@ def normalize_scenario_name(scenario_name: str) -> str: Returns: The normalized scenario name suitable for database queries. + """ # Check if it looks like snake_case (contains underscore and is lowercase) if "_" in scenario_name and scenario_name == scenario_name.lower(): @@ -196,6 +222,7 @@ def get_scorer_evaluation_metrics(self) -> Optional["ScorerMetrics"]: Returns: ScorerMetrics: The evaluation metrics object, or None if not found. + """ # import here to avoid circular imports from pyrit.score.scorer_evaluation.scorer_metrics_io import ( diff --git a/pyrit/models/score.py b/pyrit/models/score.py index b66e4b5a1c..4e97230581 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -15,6 +15,8 @@ class Score: + """Represents a normalized score generated by a scorer component.""" + id: uuid.UUID | str # The value the scorer ended up with; e.g. True (if true_false) or 0 (if float_scale) @@ -64,6 +66,26 @@ def __init__( timestamp: Optional[datetime] = None, objective: Optional[str] = None, ): + """ + Initialize a score object. + + Args: + score_value (str): Normalized score value. + score_value_description (str): Human-readable score value description. + score_type (ScoreType): Score type (true_false or float_scale). + score_rationale (str): Rationale for the score. + message_piece_id (str | uuid.UUID): ID of the scored message piece. + id (Optional[uuid.UUID | str]): Optional score ID. + score_category (Optional[List[str]]): Optional score categories. + score_metadata (Optional[Dict[str, Union[str, int, float]]]): Optional metadata. + scorer_class_identifier (Union[ScorerIdentifier, Dict[str, Any]]): Scorer identifier. + timestamp (Optional[datetime]): Optional creation timestamp. + objective (Optional[str]): Optional task objective. + + Raises: + ValueError: If score value or score type is invalid. + + """ # Import at runtime to avoid circular import from pyrit.identifiers import ScorerIdentifier @@ -90,7 +112,7 @@ def __init__( def get_value(self) -> bool | float: """ - Returns the value of the score based on its type. + Return the value of the score based on its type. If the score type is "true_false", it returns True if the score value is "true" (case-insensitive), otherwise it returns False. @@ -101,7 +123,8 @@ def get_value(self) -> bool | float: ValueError: If the score type is unknown. Returns: - The value of the score based on its type. + bool | float: Parsed score value. + """ if self.score_type == "true_false": return self.score_value.lower() == "true" @@ -111,6 +134,17 @@ def get_value(self) -> bool | float: raise ValueError(f"Unknown scorer type: {self.score_type}") def validate(self, scorer_type: str, score_value: str) -> None: + """ + Validate score value against scorer type constraints. + + Args: + scorer_type (str): Scorer type to validate against. + score_value (str): Raw score value. + + Raises: + ValueError: If value is incompatible with scorer type constraints. + + """ if scorer_type == "true_false" and str(score_value).lower() not in ["true", "false"]: raise ValueError(f"True False scorers must have a score value of 'true' or 'false' not {score_value}") elif scorer_type == "float_scale": @@ -122,6 +156,13 @@ def validate(self, scorer_type: str, score_value: str) -> None: raise ValueError(f"Float scale scorers require a numeric score value. Got {score_value}") def to_dict(self) -> Dict[str, Any]: + """ + Convert this score to a dictionary. + + Returns: + Dict[str, Any]: Serialized score payload. + + """ return { "id": str(self.id), "score_value": self.score_value, @@ -137,6 +178,13 @@ def to_dict(self) -> Dict[str, Any]: } def __str__(self) -> str: + """ + Return a concise text representation of this score. + + Returns: + str: Human-readable score summary. + + """ category_str = f": {', '.join(self.score_category) if self.score_category else ''}" if self.scorer_class_identifier: scorer_type = self.scorer_class_identifier.class_name or "Unknown" @@ -167,6 +215,17 @@ class UnvalidatedScore: timestamp: Optional[datetime] = None def to_score(self, *, score_value: str, score_type: ScoreType) -> Score: + """ + Convert this unvalidated score into a validated Score. + + Args: + score_value (str): Normalized score value. + score_type (ScoreType): Score type. + + Returns: + Score: Validated score object. + + """ return Score( id=self.id, score_value=score_value, diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index f2a6bc4697..38433d257e 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -30,18 +30,47 @@ class PartialUndefined(Undefined): + """Jinja undefined value that preserves unresolved placeholders as text.""" + # Return the original placeholder format def __str__(self) -> str: + """ + Render unresolved variable placeholders in template format. + + Returns: + str: Placeholder text or empty string. + + """ return f"{{{{ {self._undefined_name} }}}}" if self._undefined_name else "" def __repr__(self) -> str: + """ + Return the placeholder representation for debugging contexts. + + Returns: + str: Placeholder text or empty string. + + """ return f"{{{{ {self._undefined_name} }}}}" if self._undefined_name else "" def __iter__(self) -> Iterator[object]: - """Return an empty iterator to prevent Jinja from trying to loop over undefined variables.""" + """ + Return an empty iterator to prevent iteration over undefined variables. + + Returns: + Iterator[object]: Empty iterator. + + """ return iter([]) def __bool__(self) -> bool: + """ + Evaluate as truthy to avoid falsey-branch side effects. + + Returns: + bool: Always True. + + """ return True # Ensures it doesn't evaluate to False @@ -106,7 +135,7 @@ def data_type(self) -> PromptDataType: def render_template_value(self, **kwargs: Any) -> str: """ - Renders self.value as a template, applying provided parameters in kwargs. + Render self.value as a template with provided parameters. Args: kwargs:Key-value pairs to replace in the SeedPrompt value. @@ -116,6 +145,7 @@ def render_template_value(self, **kwargs: Any) -> str: Raises: ValueError: If parameters are missing or invalid in the template. + """ template_identifier = self.name or "" @@ -130,7 +160,7 @@ def render_template_value(self, **kwargs: Any) -> str: def render_template_value_silent(self, **kwargs: Any) -> str: """ - Renders self.value as a template, applying provided parameters in kwargs. For parameters in the template + Render self.value as a template with provided parameters. For parameters in the template that are not provided as kwargs here, this function will leave them as is instead of raising an error. Args: @@ -141,6 +171,7 @@ def render_template_value_silent(self, **kwargs: Any) -> str: Raises: ValueError: If parameters are missing or invalid in the template. + """ # Check if the template contains Jinja2 control structures (for loops, if statements, etc.) # If it does, and we don't have all required parameters, don't render it to preserve the structure @@ -168,7 +199,7 @@ def render_template_value_silent(self, **kwargs: Any) -> str: async def set_sha256_value_async(self) -> None: """ - This method computes the SHA256 hash value asynchronously. + Compute the SHA256 hash value asynchronously. It should be called after prompt `value` is serialized to text, as file paths used in the `value` may have changed from local to memory storage paths. diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index f16d17aa1c..62112acb4c 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -41,6 +41,7 @@ def __init__( Raises: ValueError: If seeds is empty. ValueError: If exactly one objective is not provided. + """ super().__init__(seeds=seeds) @@ -52,12 +53,19 @@ def validate(self) -> None: Raises: ValueError: If validation fails. + """ super().validate() self._enforce_exactly_one_objective() def _enforce_exactly_one_objective(self) -> None: - """Ensure exactly one objective is present.""" + """ + Ensure exactly one objective is present. + + Raises: + ValueError: If the group does not contain exactly one SeedObjective. + + """ objective_count = len([s for s in self.seeds if isinstance(s, SeedObjective)]) if objective_count != 1: raise ValueError(f"SeedAttackGroup must have exactly one objective. Found {objective_count}.") @@ -72,6 +80,7 @@ def objective(self) -> SeedObjective: Returns: The SeedObjective for this attack group. + """ obj = self._get_objective() assert obj is not None, "SeedAttackGroup should always have an objective" diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index ec20f3b208..3f78c5c6ac 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -89,6 +89,10 @@ def __init__( added_by: User who added the dataset. seed_type: The type of seeds in this dataset ("prompt", "objective", or "simulated_conversation"). is_objective: Deprecated in 0.13.0. Use seed_type="objective" instead. + + Raises: + ValueError: If seeds are missing or contain invalid/contradictory seed definitions. + """ if not seeds: raise ValueError("SeedDataset cannot be empty.") @@ -201,7 +205,7 @@ def get_values( harm_categories: Optional[Sequence[str]] = None, ) -> Sequence[str]: """ - Extracts and returns a list of prompt values from the dataset. By default, returns all of them. + Extract and return prompt values from the dataset. Args: first (Optional[int]): If provided, values from the first N prompts are included. @@ -211,6 +215,7 @@ def get_values( Returns: Sequence[str]: A list of prompt values. + """ # Filter by harm categories if specified seeds = self.seeds @@ -237,7 +242,7 @@ def get_random_values( self, *, number: PositiveInt, harm_categories: Optional[Sequence[str]] = None ) -> Sequence[str]: """ - Extracts and returns a list of random prompt values from the dataset. + Extract and return random prompt values from the dataset. Args: number (int): The number of random prompt values to return. @@ -246,6 +251,7 @@ def get_random_values( Returns: Sequence[str]: A list of prompt values. + """ prompts = self.get_values(harm_categories=harm_categories) return random.sample(prompts, min(len(prompts), number)) @@ -253,7 +259,17 @@ def get_random_values( @classmethod def from_dict(cls, data: Dict[str, Any]) -> SeedDataset: """ - Builds a SeedDataset by merging top-level defaults into each item in 'seeds'. + Build a SeedDataset by merging top-level defaults into each item in `seeds`. + + Args: + data (Dict[str, Any]): Dataset payload with top-level defaults and seed entries. + + Returns: + SeedDataset: Constructed dataset with merged defaults. + + Raises: + ValueError: If any seed entry includes a pre-set prompt_group_id. + """ # Pop out the seeds section seeds_data = data.pop("seeds", []) @@ -296,16 +312,14 @@ def from_dict(cls, data: Dict[str, Any]) -> SeedDataset: def render_template_value(self, **kwargs: object) -> None: """ - Renders self.value as a template, applying provided parameters in kwargs. + Render seed values as templates using provided parameters. Args: kwargs:Key-value pairs to replace in the SeedDataset value. - Returns: - None - Raises: ValueError: If parameters are missing or invalid in the template. + """ for seed in self.seeds: seed.value = seed.render_template_value(**kwargs) @@ -313,7 +327,7 @@ def render_template_value(self, **kwargs: object) -> None: @staticmethod def _set_seed_group_id_by_alias(seed_prompts: Sequence[dict[str, object]]) -> None: """ - Sets all seed_group_ids based on prompt_group_alias matches. + Set all seed_group_ids based on prompt_group_alias matches. This is important so the prompt_group_alias can be set in yaml to group prompts """ @@ -331,7 +345,7 @@ def _set_seed_group_id_by_alias(seed_prompts: Sequence[dict[str, object]]) -> No @staticmethod def group_seed_prompts_by_prompt_group_id(seeds: Sequence[Seed]) -> Sequence[SeedGroup]: """ - Groups the given list of Seeds by their prompt_group_id and creates + Group the given list of seeds by prompt_group_id and create SeedGroup or SeedAttackGroup instances. For each group, this method first attempts to create a SeedAttackGroup @@ -345,6 +359,7 @@ def group_seed_prompts_by_prompt_group_id(seeds: Sequence[Seed]) -> Sequence[See A list of SeedGroup or SeedAttackGroup objects, with seeds grouped by prompt_group_id. Each group will be ordered by the sequence number of the seeds, if available. + """ # Group seeds by `prompt_group_id` grouped_seeds: Dict[uuid.UUID, list[Seed]] = defaultdict(list) @@ -371,10 +386,24 @@ def group_seed_prompts_by_prompt_group_id(seeds: Sequence[Seed]) -> Sequence[See @property def prompts(self) -> Sequence[SeedPrompt]: + """ + Return all prompt-type seeds. + + Returns: + Sequence[SeedPrompt]: Prompt seeds in this dataset. + + """ return [s for s in self.seeds if isinstance(s, SeedPrompt)] @property def objectives(self) -> Sequence[SeedObjective]: + """ + Return all objective-type seeds. + + Returns: + Sequence[SeedObjective]: Objective seeds in this dataset. + + """ return [s for s in self.seeds if isinstance(s, SeedObjective)] @property @@ -384,8 +413,16 @@ def seed_groups(self) -> Sequence[SeedGroup]: Returns: Sequence[SeedGroup]: A list of SeedGroup objects, with seeds grouped by prompt_group_id. + """ return self.group_seed_prompts_by_prompt_group_id(self.seeds) def __repr__(self) -> str: + """ + Return a concise representation of the dataset. + + Returns: + str: Dataset summary string. + + """ return f"" diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 935438a627..31a5216889 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -61,6 +61,7 @@ def __init__( ValueError: If seeds is empty. ValueError: If multiple objectives are provided. ValueError: If SeedPrompt sequences overlap with SeedSimulatedConversation range. + """ if not seeds: raise ValueError("SeedGroup cannot be empty.") @@ -129,6 +130,7 @@ def validate(self) -> None: Raises: ValueError: If validation fails. + """ if not self.seeds: raise ValueError("SeedGroup cannot be empty.") @@ -139,12 +141,24 @@ def validate(self) -> None: self._enforce_no_sequence_overlap_with_simulated() def _enforce_max_one_objective(self) -> None: - """Ensure at most one objective is present.""" + """ + Ensure at most one objective is present. + + Raises: + ValueError: If more than one SeedObjective exists. + + """ if len([s for s in self.seeds if isinstance(s, SeedObjective)]) > 1: raise ValueError("SeedGroup can only have one objective.") def _enforce_max_one_simulated_conversation(self) -> None: - """Ensure at most one simulated conversation is present.""" + """ + Ensure at most one simulated conversation is present. + + Raises: + ValueError: If more than one SeedSimulatedConversation exists. + + """ if len([s for s in self.seeds if isinstance(s, SeedSimulatedConversation)]) > 1: raise ValueError("SeedGroup can only have one simulated conversation.") @@ -156,6 +170,7 @@ def _enforce_consistent_group_id(self) -> None: Raises: ValueError: If multiple different group IDs exist. + """ existing_group_ids = {seed.prompt_group_id for seed in self.seeds if seed.prompt_group_id is not None} @@ -177,6 +192,7 @@ def _enforce_consistent_role(self) -> None: Raises: ValueError: If roles are inconsistent within a sequence. ValueError: If no roles are set in a multi-sequence group. + """ grouped_prompts = defaultdict(list) for prompt in self.prompts: @@ -206,6 +222,7 @@ def _enforce_no_sequence_overlap_with_simulated(self) -> None: Raises: ValueError: If any SeedPrompt sequence overlaps with the simulated range. + """ simulated_config = self._get_simulated_conversation() if simulated_config is None: @@ -226,14 +243,26 @@ def _enforce_no_sequence_overlap_with_simulated(self) -> None: # ========================================================================= def _get_objective(self) -> Optional[SeedObjective]: - """Get the objective seed if present.""" + """ + Get the objective seed if present. + + Returns: + Optional[SeedObjective]: Objective seed when available; otherwise None. + + """ for seed in self.seeds: if isinstance(seed, SeedObjective): return seed return None def _get_simulated_conversation(self) -> Optional[SeedSimulatedConversation]: - """Get the simulated conversation seed if present.""" + """ + Get the simulated conversation seed if present. + + Returns: + Optional[SeedSimulatedConversation]: Simulated conversation seed when available; otherwise None. + + """ for seed in self.seeds: if isinstance(seed, SeedSimulatedConversation): return seed @@ -256,6 +285,7 @@ def harm_categories(self) -> List[str]: Returns: List of harm categories with duplicates removed. + """ categories: List[str] = [] for seed in self.seeds: @@ -290,6 +320,7 @@ def prepended_conversation(self) -> Optional[List[Message]]: Returns: Messages for conversation history, or None if empty. + """ if not self.prompts: return None @@ -318,6 +349,7 @@ def next_message(self) -> Optional[Message]: Returns: Message for the current/last turn if user role, or None otherwise. + """ if not self.prompts: return None @@ -344,6 +376,7 @@ def user_messages(self) -> List[Message]: Returns: All user messages in sequence order, or empty list if no prompts. + """ if not self.prompts: return [] @@ -356,6 +389,7 @@ def _get_last_sequence_role(self) -> Optional[str]: Returns: The role of the last sequence, or None if no prompts exist. + """ if not self.prompts: return None @@ -377,6 +411,7 @@ def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> List[Message]: Returns: Messages created from the prompts. + """ sequence_groups = defaultdict(list) for prompt in prompts: @@ -413,27 +448,53 @@ def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> List[Message]: def render_template_value(self, **kwargs: Any) -> None: """ - Renders seed values as templates with provided parameters. + Render seed values as templates with provided parameters. Args: kwargs: Key-value pairs to replace in seed values. + """ for seed in self.seeds: seed.value = seed.render_template_value(**kwargs) def is_single_turn(self) -> bool: - """Check if this is a single-turn group (single request without objective).""" + """ + Check if this is a single-turn group (single request without objective). + + Returns: + bool: True when the group is a single request and has no objective. + + """ return self.is_single_request() and not self.objective def is_single_request(self) -> bool: - """Check if all prompts are in a single sequence.""" + """ + Check if all prompts are in a single sequence. + + Returns: + bool: True when all prompts share one sequence number. + + """ unique_sequences = {prompt.sequence for prompt in self.prompts} return len(unique_sequences) == 1 def is_single_part_single_text_request(self) -> bool: - """Check if this is a single text prompt.""" + """ + Check if this is a single text prompt. + + Returns: + bool: True when there is exactly one prompt and it is text. + + """ return len(self.prompts) == 1 and self.prompts[0].data_type == "text" def __repr__(self) -> str: + """ + Return a concise representation of the seed group. + + Returns: + str: Seed group summary string. + + """ sim_info = " (simulated)" if self.has_simulated_conversation else "" return f"" diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index fe4ae90c6b..4fa1fdafbd 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -44,5 +44,6 @@ def from_yaml_with_required_parameters( Returns: SeedObjective: The loaded and validated seed of the specific subclass type. + """ return cls.from_yaml_file(template_path) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 8d478f7faa..048098941a 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -46,7 +46,13 @@ class SeedPrompt(Seed): parameters: Optional[Sequence[str]] = field(default_factory=lambda: []) def __post_init__(self) -> None: - """Post-initialization to render the template to replace existing values.""" + """ + Render template placeholders and infer data_type after initialization. + + Raises: + ValueError: If file-based data type cannot be inferred from extension. + + """ self.value = self.render_template_value_silent(**PATHS_DICT) if not self.data_type: @@ -68,7 +74,7 @@ def __post_init__(self) -> None: def set_encoding_metadata(self) -> None: """ - This method sets the encoding data for the prompt within metadata dictionary. For images, this is just the + Set encoding metadata for the prompt within metadata dictionary. For images, this is just the file format. For audio and video, this also includes bitrate (kBits/s as int), samplerate (samples/second as int), bitdepth (as int), filesize (bytes as int), and duration (seconds as int) if the file type is supported by TinyTag. Example supported file types include: MP3, MP4, M4A, and WAV. @@ -122,6 +128,7 @@ def from_yaml_with_required_parameters( Raises: ValueError: If the template doesn't contain all required parameters. + """ sp = cls.from_yaml_file(template_path) @@ -152,6 +159,7 @@ def from_messages( Returns: List of SeedPrompts with incrementing sequence numbers per message. + """ seed_prompts: list[SeedPrompt] = [] current_sequence = starting_sequence diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index 092708fb66..8539ceb5d5 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -63,6 +63,7 @@ class SeedSimulatedConversation(Seed): an additional user message after the simulated conversation. If provided, a single LLM call generates a final user message that attempts to get the target to fulfill the objective in their next response. + """ def __init__( @@ -95,6 +96,10 @@ def __init__( Defaults to 0. pyrit_version: PyRIT version for reproducibility tracking. Defaults to current version. **kwargs: Additional arguments passed to the Seed base class. + + Raises: + ValueError: If num_turns is not positive or sequence is negative. + """ # Apply default for simulated target system prompt if not provided if simulated_target_system_prompt_path is None: @@ -119,7 +124,13 @@ def __init__( super().__init__(value=self._compute_value(), **kwargs) def _compute_value(self) -> str: - """Compute the value field as JSON serialization of config.""" + """ + Compute the value field as JSON serialization of config. + + Returns: + str: Deterministic JSON representation of this configuration. + + """ config = { "num_turns": self.num_turns, "sequence": self.sequence, @@ -147,6 +158,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "SeedSimulatedConversation": Returns: A new SeedSimulatedConversation instance. + + Raises: + ValueError: If required configuration fields are missing. + """ adversarial_path = data.get("adversarial_chat_system_prompt_path") if not adversarial_path: @@ -180,6 +195,7 @@ def from_yaml_with_required_parameters( Raises: ValueError: If required parameters are missing. + """ instance = cls.from_yaml_file(template_path) @@ -197,6 +213,7 @@ def get_identifier(self) -> Dict[str, Any]: Returns: Dictionary with configuration details. + """ return { "__type__": "SeedSimulatedConversation", @@ -216,6 +233,7 @@ def compute_hash(self) -> str: Returns: A SHA256 hash string representing the configuration. + """ identifier = self.get_identifier() config_json = json.dumps(identifier, sort_keys=True, separators=(",", ":")) @@ -245,6 +263,7 @@ def load_simulated_target_system_prompt( Raises: ValueError: If the template doesn't have required parameters. + """ if simulated_target_system_prompt_path is None: return None @@ -271,11 +290,19 @@ def sequence_range(self) -> range: Returns: A range object representing the sequence numbers. + """ message_count = self.num_turns * 2 + (1 if self.next_message_system_prompt_path else 0) return range(self.sequence, self.sequence + message_count) def __repr__(self) -> str: + """ + Return a concise representation of this simulated conversation seed. + + Returns: + str: Simulated conversation summary string. + + """ has_next_msg = self.next_message_system_prompt_path is not None return ( f" bytes: Returns: bytes: The content of the file. + """ path = self._convert_to_path(path) async with aiofiles.open(path, "rb") as file: @@ -91,6 +92,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: Args: path (Path): The path to the file. data (bytes): The content to write to the file. + """ path = self._convert_to_path(path) async with aiofiles.open(path, "wb") as file: @@ -98,26 +100,28 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: async def path_exists(self, path: Union[Path, str]) -> bool: """ - Checks if a path exists on the local disk. + Check whether a path exists on the local disk. Args: path (Path): The path to check. Returns: bool: True if the path exists, False otherwise. + """ path = self._convert_to_path(path) return os.path.exists(path) async def is_file(self, path: Union[Path, str]) -> bool: """ - Checks if the given path is a file (not a directory). + Check whether the given path is a file (not a directory). Args: path (Path): The path to check. Returns: bool: True if the path is a file, False otherwise. + """ path = self._convert_to_path(path) return os.path.isfile(path) @@ -128,6 +132,7 @@ async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: Args: path (Path): The directory path to create. + """ directory_path = self._convert_to_path(path) if not directory_path.exists(): @@ -135,7 +140,14 @@ async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: def _convert_to_path(self, path: Union[Path, str]) -> Path: """ - Converts the path to a Path object if it's a string. + Convert an input path to a Path object. + + Args: + path (Union[Path, str]): Input path value. + + Returns: + Path: Normalized Path instance. + """ path = Path(path) if isinstance(path, str) else path return path @@ -153,6 +165,18 @@ def __init__( sas_token: Optional[str] = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, ) -> None: + """ + Initialize an Azure Blob Storage I/O adapter. + + Args: + container_url (Optional[str]): Azure Blob container URL. + sas_token (Optional[str]): Optional SAS token. + blob_content_type (SupportedContentType): Blob content type for uploads. + + Raises: + ValueError: If container_url is missing. + + """ self._blob_content_type: str = blob_content_type.value if not container_url: raise ValueError("Invalid Azure Storage Account Container URL.") @@ -163,7 +187,9 @@ def __init__( async def _create_container_client_async(self) -> None: """ - Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the + Create an asynchronous ContainerClient for Azure Storage. + + If a SAS token is provided via the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. """ @@ -184,6 +210,7 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st file_name (str): File name to assign to uploaded blob. data (bytes): Byte representation of content to upload to container. content_type (str): Content type to upload. + """ content_settings = ContentSettings(content_type=f"{content_type}") # type: ignore[no-untyped-call, unused-ignore] logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) @@ -209,7 +236,19 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st raise def parse_blob_url(self, file_path: str) -> tuple[str, str]: - """Parses the blob URL to extract the container name and blob name.""" + """ + Parse a blob URL to extract the container and blob name. + + Args: + file_path (str): Full blob URL. + + Returns: + tuple[str, str]: Container name and blob name. + + Raises: + ValueError: If file_path is not a valid blob URL. + + """ parsed_url = urlparse(file_path) if parsed_url.scheme and parsed_url.netloc: container_name = parsed_url.path.split("/")[1] @@ -243,6 +282,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: await read_file("https://account.blob.core.windows.net/container/dir2/1726627689003831.png") # Or using a relative path: file_content = await read_file("dir1/dir2/1726627689003831.png") + """ if not self._client_async: await self._create_container_client_async() @@ -267,11 +307,12 @@ async def read_file(self, path: Union[Path, str]) -> bytes: async def write_file(self, path: Union[Path, str], data: bytes) -> None: """ - Writes data to Azure Blob Storage at the specified path. + Write data to Azure Blob Storage at the specified path. Args: path (str): The full Azure Blob Storage URL data (bytes): The data to write. + """ if not self._client_async: await self._create_container_client_async() @@ -286,7 +327,16 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: self._client_async = None async def path_exists(self, path: Union[Path, str]) -> bool: - """Check if a given path exists in the Azure Blob Storage container.""" + """ + Check whether a given path exists in the Azure Blob Storage container. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the path exists. + + """ if not self._client_async: await self._create_container_client_async() try: @@ -301,7 +351,16 @@ async def path_exists(self, path: Union[Path, str]) -> bool: self._client_async = None async def is_file(self, path: Union[Path, str]) -> bool: - """Check if the path refers to a file (blob) in Azure Blob Storage.""" + """ + Check whether the path refers to a file (blob) in Azure Blob Storage. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the blob exists and has non-zero content size. + + """ if not self._client_async: await self._create_container_client_async() try: @@ -316,6 +375,13 @@ async def is_file(self, path: Union[Path, str]) -> bool: self._client_async = None async def create_directory_if_not_exists(self, directory_path: Union[Path, str]) -> None: + """ + Log a no-op directory creation for Azure Blob Storage. + + Args: + directory_path (Union[Path, str]): Requested directory path. + + """ logger.info( f"Directory creation is handled automatically during upload operations in Azure Blob Storage. " f"Directory path: {directory_path}" diff --git a/pyrit/models/strategy_result.py b/pyrit/models/strategy_result.py index e4b90cfab1..8de784e863 100644 --- a/pyrit/models/strategy_result.py +++ b/pyrit/models/strategy_result.py @@ -21,5 +21,6 @@ def duplicate(self: StrategyResultT) -> StrategyResultT: Returns: StrategyResult: A deep copy of the result. + """ return deepcopy(self) From ee9e4ad386b6f3e9a046fff9d46fe9e52f658513 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 23 Feb 2026 20:36:36 -0800 Subject: [PATCH 2/3] Address review comments on docstrings - Convert Parameters/---------- sections to Google-style Attributes: format - Add type annotations in parentheses to Args in embeddings.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/chat_message.py | 4 +--- pyrit/models/embeddings.py | 4 ++-- pyrit/models/message.py | 4 +--- pyrit/models/question_answering.py | 12 +++--------- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index bd05fd83c3..6cacb46284 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -98,13 +98,11 @@ class ChatMessagesDataset(BaseModel): """ Represents a dataset of chat messages. - Parameters - ---------- + Attributes: model_config (ConfigDict): The model configuration. name (str): The name of the dataset. description (str): The description of the dataset. list_of_chat_messages (list[list[ChatMessage]]): A list of chat messages. - """ model_config = ConfigDict(extra="forbid") diff --git a/pyrit/models/embeddings.py b/pyrit/models/embeddings.py index 0fb7c36529..a71d636b85 100644 --- a/pyrit/models/embeddings.py +++ b/pyrit/models/embeddings.py @@ -41,7 +41,7 @@ def save_to_file(self, directory_path: Path) -> str: Save the embedding response to disk and return the path of the new file. Args: - directory_path: The path to save the file to. + directory_path (Path): The path to save the file to. Returns: str: The full path to the file that was saved. @@ -59,7 +59,7 @@ def load_from_file(file_path: Path) -> EmbeddingResponse: Load the embedding response from disk. Args: - file_path: The path to load the file from. + file_path (Path): The path to load the file from. Returns: EmbeddingResponse: The loaded embedding response. diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 0ba939e5cd..565c7bd052 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -20,10 +20,8 @@ class Message: This is a single request to a target. It can contain multiple message pieces. - Parameters - ---------- + Attributes: message_pieces (Sequence[MessagePiece]): The list of message pieces. - """ def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False) -> None: diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index 610eb7ae37..59878b6b64 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -12,11 +12,9 @@ class QuestionChoice(BaseModel): """ Represents a choice for a question. - Parameters - ---------- + Attributes: index (int): The index of the choice. text (str): The text of the choice. - """ model_config = ConfigDict(extra="forbid") @@ -28,8 +26,7 @@ class QuestionAnsweringEntry(BaseModel): """ Represents a question model. - Parameters - ---------- + Attributes: question (str): The question text. answer_type (Literal["int", "float", "str", "bool"]): The type of the answer. `int` for integer answers (e.g., when the answer is an index of the correct option in a multiple-choice @@ -39,7 +36,6 @@ class QuestionAnsweringEntry(BaseModel): `bool` for boolean answers. correct_answer (Union[int, str, float]): The correct answer. choices (list[QuestionChoice]): The list of choices for the question. - """ model_config = ConfigDict(extra="forbid") @@ -84,8 +80,7 @@ class QuestionAnsweringDataset(BaseModel): """ Represents a dataset for question answering. - Parameters - ---------- + Attributes: name (str): The name of the dataset. version (str): The version of the dataset. description (str): A description of the dataset. @@ -93,7 +88,6 @@ class QuestionAnsweringDataset(BaseModel): group (str): The group associated with the dataset. source (str): The source of the dataset. questions (list[QuestionAnsweringEntry]): A list of question models. - """ model_config = ConfigDict(extra="forbid") From 689870783f61a721dbe9174203090aa9af59e857 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 23 Feb 2026 20:59:39 -0800 Subject: [PATCH 3/3] Remove Attributes: sections from data class docstrings to fix Sphinx duplicate object warnings Sphinx autosummary auto-documents class attributes from type annotations, so having Attributes: in the docstring creates duplicate entries that cause the docs build to fail. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/chat_message.py | 6 ------ pyrit/models/message.py | 3 --- pyrit/models/question_answering.py | 24 ------------------------ 3 files changed, 33 deletions(-) diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index 6cacb46284..6dea332aca 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -97,12 +97,6 @@ def __init__(self, **data: Any) -> None: class ChatMessagesDataset(BaseModel): """ Represents a dataset of chat messages. - - Attributes: - model_config (ConfigDict): The model configuration. - name (str): The name of the dataset. - description (str): The description of the dataset. - list_of_chat_messages (list[list[ChatMessage]]): A list of chat messages. """ model_config = ConfigDict(extra="forbid") diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 565c7bd052..340e052466 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -19,9 +19,6 @@ class Message: Represents a message in a conversation, for example a prompt or a response to a prompt. This is a single request to a target. It can contain multiple message pieces. - - Attributes: - message_pieces (Sequence[MessagePiece]): The list of message pieces. """ def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False) -> None: diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index 59878b6b64..1d526a2dad 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -11,10 +11,6 @@ class QuestionChoice(BaseModel): """ Represents a choice for a question. - - Attributes: - index (int): The index of the choice. - text (str): The text of the choice. """ model_config = ConfigDict(extra="forbid") @@ -25,17 +21,6 @@ class QuestionChoice(BaseModel): class QuestionAnsweringEntry(BaseModel): """ Represents a question model. - - Attributes: - question (str): The question text. - answer_type (Literal["int", "float", "str", "bool"]): The type of the answer. - `int` for integer answers (e.g., when the answer is an index of the correct option in a multiple-choice - question). - `float` for answers that are floating-point numbers. - `str` for text-based answers. - `bool` for boolean answers. - correct_answer (Union[int, str, float]): The correct answer. - choices (list[QuestionChoice]): The list of choices for the question. """ model_config = ConfigDict(extra="forbid") @@ -79,15 +64,6 @@ def __hash__(self) -> int: class QuestionAnsweringDataset(BaseModel): """ Represents a dataset for question answering. - - Attributes: - name (str): The name of the dataset. - version (str): The version of the dataset. - description (str): A description of the dataset. - author (str): The author of the dataset. - group (str): The group associated with the dataset. - source (str): The source of the dataset. - questions (list[QuestionAnsweringEntry]): A list of question models. """ model_config = ConfigDict(extra="forbid")