diff --git a/build_scripts/check_links.py b/build_scripts/check_links.py index 7d99221313..6c64077370 100644 --- a/build_scripts/check_links.py +++ b/build_scripts/check_links.py @@ -34,7 +34,7 @@ def extract_urls(file_path): - with open(file_path, "r", encoding="utf-8") as file: + with open(file_path, encoding="utf-8") as file: content = file.read() matches = URL_PATTERN.findall(content) # Flatten the list of tuples and filter out empty strings diff --git a/build_scripts/generate_rss.py b/build_scripts/generate_rss.py index bf418b8bfe..ba17c9cf27 100644 --- a/build_scripts/generate_rss.py +++ b/build_scripts/generate_rss.py @@ -60,7 +60,7 @@ def handle_data(self, data): fe.guid(f"https://azure.github.io/PyRIT/blog/{file.name}") # Extract title and description from HTML content - with open(file, "r", encoding="utf-8") as f: + with open(file, encoding="utf-8") as f: parser = BlogEntryParser() parser.feed(f.read()) fe.title(parser.title) diff --git a/build_scripts/remove_notebook_headers.py b/build_scripts/remove_notebook_headers.py index c3f0ad5065..824c1851a4 100644 --- a/build_scripts/remove_notebook_headers.py +++ b/build_scripts/remove_notebook_headers.py @@ -10,7 +10,7 @@ def remove_kernelspec_from_ipynb_files(file_path: str): if file_path.endswith(".ipynb"): # Iterate through all .ipynb files in the specified file - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: content = json.load(f) # Remove the "kernelspec" metadata section if it exists if "metadata" in content and "kernelspec" in content["metadata"]: diff --git a/build_scripts/validate_jupyter_book.py b/build_scripts/validate_jupyter_book.py index 1b067207ef..50dfbfedc8 100644 --- a/build_scripts/validate_jupyter_book.py +++ b/build_scripts/validate_jupyter_book.py @@ -20,19 +20,18 @@ import re import sys from pathlib import Path -from typing import List, Set, Tuple import yaml -def parse_api_rst(api_rst_path: Path) -> List[Tuple[str, List[str]]]: +def parse_api_rst(api_rst_path: Path) -> list[tuple[str, list[str]]]: """ Parse api.rst file to extract module names and their autosummary members. Returns: List of tuples: (module_name, [member_names]) """ - with open(api_rst_path, "r", encoding="utf-8") as f: + with open(api_rst_path, encoding="utf-8") as f: content = f.read() modules = [] @@ -61,7 +60,7 @@ def parse_api_rst(api_rst_path: Path) -> List[Tuple[str, List[str]]]: return modules -def validate_api_rst_modules(modules: List[Tuple[str, List[str]]], repo_root: Path) -> List[str]: +def validate_api_rst_modules(modules: list[tuple[str, list[str]]], repo_root: Path) -> list[str]: """ Validate that modules exist and autosummary members are defined. @@ -110,7 +109,7 @@ def validate_api_rst_modules(modules: List[Tuple[str, List[str]]], repo_root: Pa if module_file: # Read the source file and check for member definitions try: - with open(module_file, "r", encoding="utf-8") as f: + with open(module_file, encoding="utf-8") as f: source_content = f.read() for member in members: @@ -141,14 +140,14 @@ def validate_api_rst_modules(modules: List[Tuple[str, List[str]]], repo_root: Pa return errors -def parse_toc_yml(toc_path: Path) -> Set[str]: +def parse_toc_yml(toc_path: Path) -> set[str]: """ Parse _toc.yml file to extract all file references. Returns: Set of file paths (relative to doc/ directory, without extensions) """ - with open(toc_path, "r", encoding="utf-8") as f: + with open(toc_path, encoding="utf-8") as f: toc_data = yaml.safe_load(f) files = set() @@ -169,7 +168,7 @@ def extract_files(node): return files -def validate_toc_yml_files(toc_files: Set[str], doc_root: Path) -> List[str]: +def validate_toc_yml_files(toc_files: set[str], doc_root: Path) -> list[str]: """ Validate that all files referenced in _toc.yml exist. @@ -198,7 +197,7 @@ def validate_toc_yml_files(toc_files: Set[str], doc_root: Path) -> List[str]: return errors -def find_orphaned_doc_files(toc_files: Set[str], doc_root: Path) -> List[str]: +def find_orphaned_doc_files(toc_files: set[str], doc_root: Path) -> list[str]: """ Find documentation files that exist but are not referenced in _toc.yml. diff --git a/doc/code/executor/workflow/1_xpia_website.ipynb b/doc/code/executor/workflow/1_xpia_website.ipynb index b1416af856..ad027059bf 100644 --- a/doc/code/executor/workflow/1_xpia_website.ipynb +++ b/doc/code/executor/workflow/1_xpia_website.ipynb @@ -30,7 +30,7 @@ "from pyrit.models import Message, MessagePiece\n", "\n", "# Read basic HTML file with template slot for the XPIA.\n", - "with open(Path().cwd() / \"example\" / \"index.html\", \"r\") as f:\n", + "with open(Path().cwd() / \"example\" / \"index.html\") as f:\n", " html_template = f.read()\n", "jailbreak_template = TextJailBreak(string_template=html_template)\n", "\n", diff --git a/doc/code/executor/workflow/1_xpia_website.py b/doc/code/executor/workflow/1_xpia_website.py index 13ce18d486..aae0d6df03 100644 --- a/doc/code/executor/workflow/1_xpia_website.py +++ b/doc/code/executor/workflow/1_xpia_website.py @@ -25,7 +25,7 @@ from pyrit.models import Message, MessagePiece # Read basic HTML file with template slot for the XPIA. -with open(Path().cwd() / "example" / "index.html", "r") as f: +with open(Path().cwd() / "example" / "index.html") as f: html_template = f.read() jailbreak_template = TextJailBreak(string_template=html_template) diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index 13c22ec6e7..0c21846f0f 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -96,7 +96,7 @@ } ], "source": [ - "from typing import List, Optional, Type\n", + "from typing import Optional\n", "\n", "from pyrit.common import apply_defaults\n", "from pyrit.executor.attack import AttackScoringConfig, PromptSendingAttack\n", @@ -124,7 +124,7 @@ "\n", " # A strategy defintion helps callers define how to run your scenario (e.g. from the front_end)\n", " @classmethod\n", - " def get_strategy_class(cls) -> Type[ScenarioStrategy]:\n", + " def get_strategy_class(cls) -> type[ScenarioStrategy]:\n", " return MyStrategy\n", "\n", " @classmethod\n", @@ -155,7 +155,7 @@ " scenario_result_id=scenario_result_id,\n", " )\n", "\n", - " async def _get_atomic_attacks_async(self) -> List[AtomicAttack]:\n", + " async def _get_atomic_attacks_async(self) -> list[AtomicAttack]:\n", " \"\"\"\n", " Build atomic attacks based on selected strategies.\n", "\n", diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index c6ba6d8ad4..992b95aeaa 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -84,7 +84,7 @@ # # ### Example Structure # %% -from typing import List, Optional, Type +from typing import Optional from pyrit.common import apply_defaults from pyrit.executor.attack import AttackScoringConfig, PromptSendingAttack @@ -112,7 +112,7 @@ class MyScenario(Scenario): # A strategy defintion helps callers define how to run your scenario (e.g. from the front_end) @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: + def get_strategy_class(cls) -> type[ScenarioStrategy]: return MyStrategy @classmethod @@ -143,7 +143,7 @@ def __init__( scenario_result_id=scenario_result_id, ) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Build atomic attacks based on selected strategies. diff --git a/doc/code/targets/11_message_normalizer.ipynb b/doc/code/targets/11_message_normalizer.ipynb index cf6f6a26a1..1377311dd8 100644 --- a/doc/code/targets/11_message_normalizer.ipynb +++ b/doc/code/targets/11_message_normalizer.ipynb @@ -491,8 +491,6 @@ } ], "source": [ - "from typing import List\n", - "\n", "from pyrit.message_normalizer import MessageStringNormalizer\n", "from pyrit.models import Message\n", "\n", @@ -500,7 +498,7 @@ "class SimpleMarkdownNormalizer(MessageStringNormalizer):\n", " \"\"\"Custom normalizer that formats messages as Markdown.\"\"\"\n", "\n", - " async def normalize_string_async(self, messages: List[Message]) -> str:\n", + " async def normalize_string_async(self, messages: list[Message]) -> str:\n", " lines = []\n", " for msg in messages:\n", " piece = msg.get_piece()\n", diff --git a/doc/code/targets/11_message_normalizer.py b/doc/code/targets/11_message_normalizer.py index 775c45b346..b5e2c9181c 100644 --- a/doc/code/targets/11_message_normalizer.py +++ b/doc/code/targets/11_message_normalizer.py @@ -203,7 +203,6 @@ # You can create custom normalizers by extending the base classes. # %% -from typing import List from pyrit.message_normalizer import MessageStringNormalizer from pyrit.models import Message @@ -212,7 +211,7 @@ class SimpleMarkdownNormalizer(MessageStringNormalizer): """Custom normalizer that formats messages as Markdown.""" - async def normalize_string_async(self, messages: List[Message]) -> str: + async def normalize_string_async(self, messages: list[Message]) -> str: lines = [] for msg in messages: piece = msg.get_piece() diff --git a/doc/code/targets/6_custom_targets.ipynb b/doc/code/targets/6_custom_targets.ipynb index b37f2a8010..3d2880ea4a 100644 --- a/doc/code/targets/6_custom_targets.ipynb +++ b/doc/code/targets/6_custom_targets.ipynb @@ -254,8 +254,6 @@ } ], "source": [ - "from typing import List\n", - "\n", "from pyrit.executor.attack import (\n", " AttackConverterConfig,\n", " AttackScoringConfig,\n", @@ -281,7 +279,7 @@ "\n", "aoai_target = OpenAIChatTarget()\n", "\n", - "converters: List[PromptConverterConfiguration] = PromptConverterConfiguration.from_converters(\n", + "converters: list[PromptConverterConfiguration] = PromptConverterConfiguration.from_converters(\n", " converters=[RandomCapitalLettersConverter(percentage=25)]\n", ")\n", "\n", diff --git a/doc/code/targets/6_custom_targets.py b/doc/code/targets/6_custom_targets.py index c588539988..d8352aa460 100644 --- a/doc/code/targets/6_custom_targets.py +++ b/doc/code/targets/6_custom_targets.py @@ -103,7 +103,6 @@ # Below is an example of using PromptSendingAttack, which allows the use of all our converters. For example, you could use this to utilize all the built-in jailbreaks, base64 encode them, use variations, different languages, etc. # %% -from typing import List from pyrit.executor.attack import ( AttackConverterConfig, @@ -130,7 +129,7 @@ aoai_target = OpenAIChatTarget() -converters: List[PromptConverterConfiguration] = PromptConverterConfiguration.from_converters( +converters: list[PromptConverterConfiguration] = PromptConverterConfiguration.from_converters( converters=[RandomCapitalLettersConverter(percentage=25)] ) diff --git a/doc/cookbooks/5_psychosocial_harms.ipynb b/doc/cookbooks/5_psychosocial_harms.ipynb index 636a1b11ac..feb63cb577 100644 --- a/doc/cookbooks/5_psychosocial_harms.ipynb +++ b/doc/cookbooks/5_psychosocial_harms.ipynb @@ -195,7 +195,7 @@ "print(\"Attack Technique using Escalation for a user in imminent crisis:\")\n", "attack_strategy_path = pathlib.Path(DATASETS_PATH) / \"executors\" / \"crescendo\" / \"escalation_crisis.yaml\"\n", "\n", - "with open(attack_strategy_path, \"r\") as file:\n", + "with open(attack_strategy_path) as file:\n", " print(file.read())" ] }, diff --git a/doc/cookbooks/5_psychosocial_harms.py b/doc/cookbooks/5_psychosocial_harms.py index 6463a42822..5c0520557e 100644 --- a/doc/cookbooks/5_psychosocial_harms.py +++ b/doc/cookbooks/5_psychosocial_harms.py @@ -77,7 +77,7 @@ print("Attack Technique using Escalation for a user in imminent crisis:") attack_strategy_path = pathlib.Path(DATASETS_PATH) / "executors" / "crescendo" / "escalation_crisis.yaml" -with open(attack_strategy_path, "r") as file: +with open(attack_strategy_path) as file: print(file.read()) # %% [markdown] diff --git a/doc/deployment/deploy_hf_model_aml.ipynb b/doc/deployment/deploy_hf_model_aml.ipynb index 08db471277..84c09152d7 100644 --- a/doc/deployment/deploy_hf_model_aml.ipynb +++ b/doc/deployment/deploy_hf_model_aml.ipynb @@ -186,16 +186,12 @@ "if check_model_version_exists(workspace_ml_client, model_to_deploy, model_version):\n", " print(\"Model found in the Azure ML workspace model registry.\")\n", " model = workspace_ml_client.models.get(model_to_deploy, model_version)\n", - " print(\n", - " \"\\n\\nUsing model name: {0}, version: {1}, id: {2} for inferencing\".format(model.name, model.version, model.id)\n", - " )\n", + " print(f\"\\n\\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inferencing\")\n", "# Check if the Hugging Face model exists in the Azure ML model catalog registry\n", "elif check_model_version_exists(registry_ml_client, model_to_deploy, model_version):\n", " print(\"Model found in the Azure ML model catalog registry.\")\n", " model = registry_ml_client.models.get(model_to_deploy, model_version)\n", - " print(\n", - " \"\\n\\nUsing model name: {0}, version: {1}, id: {2} for inferencing\".format(model.name, model.version, model.id)\n", - " )\n", + " print(f\"\\n\\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inferencing\")\n", "else:\n", " raise ValueError(\n", " f\"Model {model_to_deploy} not found in any registry. Please run the notebook (download_and_register_hf_model_aml.ipynb) to download and register Hugging Face model to Azure ML workspace model registry.\"\n", diff --git a/doc/deployment/deploy_hf_model_aml.py b/doc/deployment/deploy_hf_model_aml.py index 8e6029c234..d21c95aee7 100644 --- a/doc/deployment/deploy_hf_model_aml.py +++ b/doc/deployment/deploy_hf_model_aml.py @@ -154,16 +154,12 @@ def check_model_version_exists(client, model_name, version) -> bool: if check_model_version_exists(workspace_ml_client, model_to_deploy, model_version): print("Model found in the Azure ML workspace model registry.") model = workspace_ml_client.models.get(model_to_deploy, model_version) - print( - "\n\nUsing model name: {0}, version: {1}, id: {2} for inferencing".format(model.name, model.version, model.id) - ) + print(f"\n\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inferencing") # Check if the Hugging Face model exists in the Azure ML model catalog registry elif check_model_version_exists(registry_ml_client, model_to_deploy, model_version): print("Model found in the Azure ML model catalog registry.") model = registry_ml_client.models.get(model_to_deploy, model_version) - print( - "\n\nUsing model name: {0}, version: {1}, id: {2} for inferencing".format(model.name, model.version, model.id) - ) + print(f"\n\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inferencing") else: raise ValueError( f"Model {model_to_deploy} not found in any registry. Please run the notebook (download_and_register_hf_model_aml.ipynb) to download and register Hugging Face model to Azure ML workspace model registry." diff --git a/doc/generate_docs/pct_to_ipynb.py b/doc/generate_docs/pct_to_ipynb.py index 3a1b9d9e04..b6bd1fbf07 100644 --- a/doc/generate_docs/pct_to_ipynb.py +++ b/doc/generate_docs/pct_to_ipynb.py @@ -52,7 +52,7 @@ def main(): cache_file = os.path.join(cache_dir, f"pct_to_ipynb_{args.run_id}.cache") processed_files = set() if os.path.isfile(cache_file): - with open(cache_file, "r") as f: + with open(cache_file) as f: for file_path in f: processed_files.add(file_path.strip()) diff --git a/frontend/dev.py b/frontend/dev.py index f699ccf84e..1df8b02090 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -42,7 +42,7 @@ def sync_version(): # Read package.json package_json_path = FRONTEND_DIR / "package.json" - with open(package_json_path, "r") as f: + with open(package_json_path) as f: package_data = json.load(f) # Update version if different diff --git a/pyproject.toml b/pyproject.toml index e1fb01ac81..10d1751d93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -253,6 +253,7 @@ select = [ "F401", # unused-import "I", # isort "RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret + "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up "W", # https://docs.astral.sh/ruff/rules/#pycodestyle-w ] ignore = [ @@ -262,6 +263,8 @@ ignore = [ "D212", # Multi-line docstring summary should start at the first line "D301", # Use r""" if any backslashes in a docstring "DOC502", # Raised exception is not explicitly raised + "UP007", # non-pep604-annotation-union (keep Union[X, Y] syntax) + "UP045", # non-pep604-annotation-optional (keep Optional[X] syntax) ] extend-select = [ "D204", # 1 blank line required after class docstring diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index a6e260af39..cc5f58fa70 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import DefaultDict, Optional +from typing import Optional from pyrit.models import AttackOutcome, AttackResult @@ -54,8 +54,8 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats if not attack_results: raise ValueError("attack_results cannot be empty") - overall_counts: DefaultDict[str, int] = defaultdict(int) - by_type_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(int)) + overall_counts: defaultdict[str, int] = defaultdict(int) + by_type_counts: defaultdict[str, defaultdict[str, int]] = defaultdict(lambda: defaultdict(int)) for attack in attack_results: if not isinstance(attack, AttackResult): diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index d902e419c4..a1f526545d 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,7 +5,8 @@ import logging import time -from typing import TYPE_CHECKING, Any, Callable, Union, cast +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Union, cast from urllib.parse import urlparse import msal diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 02edf7f611..da9f79e7a5 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -337,11 +337,9 @@ def logits(self, model: Any, test_controls: Any = None, return_ids: bool = False if not (test_ids[0].shape[0] == self._control_slice.stop - self._control_slice.start): raise ValueError( - ( - f"test_controls must have shape " - f"(n, {self._control_slice.stop - self._control_slice.start}), " - f"got {test_ids.shape}" - ) + f"test_controls must have shape " + f"(n, {self._control_slice.stop - self._control_slice.start}), " + f"got {test_ids.shape}" ) locs = ( @@ -583,14 +581,14 @@ def __init__( self, goals: list[str], targets: list[str], - workers: list["ModelWorker"], + workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", test_prefixes: Optional[list[str]] = None, logfile: Optional[str] = None, managers: Optional[dict[str, Any]] = None, test_goals: Optional[list[str]] = None, test_targets: Optional[list[str]] = None, - test_workers: Optional[list["ModelWorker"]] = None, + test_workers: Optional[list[ModelWorker]] = None, ) -> None: """ Initializes the MultiPromptAttack object with the provided parameters. @@ -798,7 +796,7 @@ def control_weight_fn(_: int) -> float: return self.control_str, loss, steps def test( - self, workers: list["ModelWorker"], prompts: list[PromptManager], include_loss: bool = False + self, workers: list[ModelWorker], prompts: list[PromptManager], include_loss: bool = False ) -> tuple[list[list[bool]], list[list[int]], list[list[float]]]: for j, worker in enumerate(workers): worker(prompts[j], "test", worker.model) @@ -874,7 +872,7 @@ def log( tests["n_loss"] = n_loss tests["total"] = total_tests - with open(self.logfile, "r") as f: + with open(self.logfile) as f: log = json.load(f) log["controls"].append(control) @@ -919,7 +917,7 @@ def __init__( self, goals: list[str], targets: list[str], - workers: list["ModelWorker"], + workers: list[ModelWorker], progressive_goals: bool = True, progressive_models: bool = True, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", @@ -928,7 +926,7 @@ def __init__( managers: Optional[dict[str, Any]] = None, test_goals: Optional[list[str]] = None, test_targets: Optional[list[str]] = None, - test_workers: Optional[list["ModelWorker"]] = None, + test_workers: Optional[list[ModelWorker]] = None, **kwargs: Any, ) -> None: """ @@ -1077,7 +1075,7 @@ def run( Whether to filter candidates whose lengths changed after re-tokenization (default is True) """ if self.logfile is not None: - with open(self.logfile, "r") as f: + with open(self.logfile) as f: log = json.load(f) log["params"]["n_steps"] = n_steps @@ -1168,14 +1166,14 @@ def __init__( self, goals: list[str], targets: list[str], - workers: list["ModelWorker"], + workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", test_prefixes: Optional[list[str]] = None, logfile: Optional[str] = None, managers: Optional[dict[str, Any]] = None, test_goals: Optional[list[str]] = None, test_targets: Optional[list[str]] = None, - test_workers: Optional[list["ModelWorker"]] = None, + test_workers: Optional[list[ModelWorker]] = None, **kwargs: Any, ) -> None: """ @@ -1317,7 +1315,7 @@ def run( Whether to filter candidates (default is True) """ if self.logfile is not None: - with open(self.logfile, "r") as f: + with open(self.logfile) as f: log = json.load(f) log["params"]["n_steps"] = n_steps @@ -1381,14 +1379,14 @@ def __init__( self, goals: list[str], targets: list[str], - workers: list["ModelWorker"], + workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", test_prefixes: Optional[list[str]] = None, logfile: Optional[str] = None, managers: Optional[dict[str, Any]] = None, test_goals: Optional[list[str]] = None, test_targets: Optional[list[str]] = None, - test_workers: Optional[list["ModelWorker"]] = None, + test_workers: Optional[list[ModelWorker]] = None, **kwargs: Any, ) -> None: """ @@ -1499,7 +1497,7 @@ def run( tokenizer.padding_side = "left" if self.logfile is not None: - with open(self.logfile, "r") as f: + with open(self.logfile) as f: log = json.load(f) log["params"]["num_tests"] = len(controls) @@ -1630,20 +1628,20 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any] results.put(fn(*args, **kwargs)) tasks.task_done() - def start(self) -> "ModelWorker": + def start(self) -> ModelWorker: self.process = mp.Process(target=ModelWorker.run, args=(self.model, self.tasks, self.results)) self.process.start() logger.info(f"Started worker {self.process.pid} for model {self.model.name_or_path}") return self - def stop(self) -> "ModelWorker": + def stop(self) -> ModelWorker: self.tasks.put(None) if self.process is not None: self.process.join() torch.cuda.empty_cache() return self - def __call__(self, ob: Any, fn: str, *args: Any, **kwargs: Any) -> "ModelWorker": + def __call__(self, ob: Any, fn: str, *args: Any, **kwargs: Any) -> ModelWorker: self.tasks.put((deepcopy(ob), fn, args, kwargs)) return self @@ -1720,8 +1718,8 @@ def get_workers(params: Any, eval: bool = False) -> tuple[list[ModelWorker], lis worker.start() num_train_models = getattr(params, "num_train_models", len(workers)) - logger.info("Loaded {} train models".format(num_train_models)) - logger.info("Loaded {} test models".format(len(workers) - num_train_models)) + logger.info(f"Loaded {num_train_models} train models") + logger.info(f"Loaded {len(workers) - num_train_models} test models") return workers[:num_train_models], workers[num_train_models:] @@ -1763,7 +1761,7 @@ def get_goals_and_targets(params: Any) -> tuple[list[str], list[str], list[str], ) if len(test_goals) != len(test_targets): raise ValueError(f"Length of test_goals ({len(test_goals)}) and test_targets ({len(test_targets)}) must match") - logger.info("Loaded {} train goals".format(len(train_goals))) - logger.info("Loaded {} test goals".format(len(test_goals))) + logger.info(f"Loaded {len(train_goals)} train goals") + logger.info(f"Loaded {len(test_goals)} test goals") return train_goals, train_targets, test_goals, test_targets diff --git a/pyrit/auxiliary_attacks/gcg/experiments/run.py b/pyrit/auxiliary_attacks/gcg/experiments/run.py index 09cf98130c..342db0a67a 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/run.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/run.py @@ -3,7 +3,7 @@ import argparse import os -from typing import Any, Dict, Union +from typing import Any, Union import yaml @@ -24,7 +24,7 @@ def _load_yaml_to_dict(config_path: str) -> dict[str, Any]: Returns: dict[str, Any]: The parsed configuration dictionary. """ - with open(config_path, "r") as f: + with open(config_path) as f: data: dict[str, Any] = yaml.safe_load(f) return data @@ -52,7 +52,7 @@ def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parame hf_token = os.environ.get("HUGGINGFACE_TOKEN") if not hf_token: raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable") - runtime_config: Dict[str, Union[str, bool, Any]] = { + runtime_config: dict[str, Union[str, bool, Any]] = { "train_data": ( "https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv" ), diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index d3592102c5..e910c2b69c 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -12,8 +12,9 @@ import mimetypes import uuid +from collections.abc import Sequence from datetime import datetime, timezone -from typing import Dict, List, Optional, Sequence, cast +from typing import Optional, cast from pyrit.backend.models.attacks import ( AddMessageRequest, @@ -81,7 +82,7 @@ def attack_result_to_summary( ) -def pyrit_scores_to_dto(scores: List[PyritScore]) -> List[Score]: +def pyrit_scores_to_dto(scores: list[PyritScore]) -> list[Score]: """ Translate PyRIT score objects to backend Score DTOs. @@ -121,7 +122,7 @@ def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Opti return mime_type -def pyrit_messages_to_dto(pyrit_messages: List[PyritMessage]) -> List[Message]: +def pyrit_messages_to_dto(pyrit_messages: list[PyritMessage]) -> list[Message]: """ Translate PyRIT messages to backend Message DTOs. @@ -173,7 +174,7 @@ def request_piece_to_pyrit_message_piece( role: ChatMessageRole, conversation_id: str, sequence: int, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> PyritMessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -188,7 +189,7 @@ def request_piece_to_pyrit_message_piece( Returns: PyritMessagePiece domain object. """ - metadata: Optional[Dict[str, str | int]] = {"mime_type": piece.mime_type} if piece.mime_type else None + metadata: Optional[dict[str, str | int]] = {"mime_type": piece.mime_type} if piece.mime_type else None original_prompt_id = uuid.UUID(piece.original_prompt_id) if piece.original_prompt_id else None return PyritMessagePiece( role=role, @@ -209,7 +210,7 @@ def request_to_pyrit_message( request: AddMessageRequest, conversation_id: str, sequence: int, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> PyritMessage: """ Build a PyRIT Message from an AddMessageRequest DTO. @@ -255,7 +256,7 @@ def _get_preview_from_pieces(pieces: Sequence[PyritMessagePiece]) -> Optional[st return text[:100] + "..." if len(text) > 100 else text -def _collect_labels_from_pieces(pieces: Sequence[PyritMessagePiece]) -> Dict[str, str]: +def _collect_labels_from_pieces(pieces: Sequence[PyritMessagePiece]) -> dict[str, str]: """ Collect labels from message pieces. diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index d69ab96df3..f1d097762d 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -5,7 +5,7 @@ Converter mappers – domain → DTO translation for converter-related models. """ -from typing import List, Optional +from typing import Optional from pyrit.backend.models.converters import ConverterInstance from pyrit.prompt_converter import PromptConverter @@ -21,7 +21,7 @@ def converter_object_to_instance( converter_id: str, converter_obj: PromptConverter, *, - sub_converter_ids: Optional[List[str]] = None, + sub_converter_ids: Optional[list[str]] = None, ) -> ConverterInstance: """ Build a ConverterInstance DTO from a registry converter object. diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 9bcf13ae19..9183d933cf 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -9,7 +9,7 @@ """ from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, Field @@ -46,7 +46,7 @@ class MessagePiece(BaseModel): original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") - scores: List[Score] = Field(default_factory=list, description="Scores embedded in this piece") + scores: list[Score] = Field(default_factory=list, description="Scores embedded in this piece") response_error: PromptResponseError = Field( default="none", description="Error status: none, processing, blocked, empty, unknown" ) @@ -60,7 +60,7 @@ class Message(BaseModel): turn_number: int = Field(..., description="Turn number in the conversation (1-indexed)") role: ChatMessageRole = Field(..., description="Message role") - pieces: List[MessagePiece] = Field(..., description="Message pieces (multimodal support)") + pieces: list[MessagePiece] = Field(..., description="Message pieces (multimodal support)") created_at: datetime = Field(..., description="Message creation timestamp") @@ -74,10 +74,10 @@ class AttackSummary(BaseModel): conversation_id: str = Field(..., description="Unique attack identifier") attack_type: str = Field(..., description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") - attack_specific_params: Optional[Dict[str, Any]] = Field(None, description="Additional attack-specific parameters") + attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") target_unique_name: Optional[str] = Field(None, description="Unique name of the objective target") target_type: Optional[str] = Field(None, description="Target class name (e.g., 'OpenAIChatTarget')") - converters: List[str] = Field( + converters: list[str] = Field( default_factory=list, description="Request converter class names applied in this attack" ) outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( @@ -87,7 +87,7 @@ class AttackSummary(BaseModel): None, description="Preview of the last message (truncated to ~100 chars)" ) message_count: int = Field(0, description="Total number of messages in the attack") - labels: Dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") + labels: dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") created_at: datetime = Field(..., description="Attack creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") @@ -101,7 +101,7 @@ class AttackMessagesResponse(BaseModel): """Response containing all messages for an attack.""" conversation_id: str = Field(..., description="Attack identifier") - messages: List[Message] = Field(default_factory=list, description="All messages in order") + messages: list[Message] = Field(default_factory=list, description="All messages in order") # ============================================================================ @@ -112,14 +112,14 @@ class AttackMessagesResponse(BaseModel): class AttackListResponse(BaseModel): """Paginated response for listing attacks.""" - items: List[AttackSummary] = Field(..., description="List of attack summaries") + items: list[AttackSummary] = Field(..., description="List of attack summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") class AttackOptionsResponse(BaseModel): """Response containing unique attack class names used across attacks.""" - attack_classes: List[str] = Field( + attack_classes: list[str] = Field( ..., description="Sorted list of unique attack class names found in attack results" ) @@ -127,7 +127,7 @@ class AttackOptionsResponse(BaseModel): class ConverterOptionsResponse(BaseModel): """Response containing unique converter class names used across attacks.""" - converter_classes: List[str] = Field( + converter_classes: list[str] = Field( ..., description="Sorted list of unique converter class names found in attack results" ) @@ -160,7 +160,7 @@ class PrependedMessageRequest(BaseModel): """A message to prepend to the attack (for system prompt/branching).""" role: ChatMessageRole = Field(..., description="Message role") - pieces: List[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)", max_length=50) + pieces: list[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)", max_length=50) class CreateAttackRequest(BaseModel): @@ -168,10 +168,10 @@ class CreateAttackRequest(BaseModel): name: Optional[str] = Field(None, description="Attack name/label") target_unique_name: str = Field(..., description="Target instance ID to attack") - prepended_conversation: Optional[List[PrependedMessageRequest]] = Field( + prepended_conversation: Optional[list[PrependedMessageRequest]] = Field( None, description="Messages to prepend (system prompts, branching context)", max_length=200 ) - labels: Optional[Dict[str, str]] = Field(None, description="User-defined labels for filtering") + labels: Optional[dict[str, str]] = Field(None, description="User-defined labels for filtering") class CreateAttackResponse(BaseModel): @@ -207,12 +207,12 @@ class AddMessageRequest(BaseModel): """ role: ChatMessageRole = Field(default="user", description="Message role") - pieces: List[MessagePieceRequest] = Field(..., description="Message pieces", max_length=50) + pieces: list[MessagePieceRequest] = Field(..., description="Message pieces", max_length=50) send: bool = Field( default=True, description="If True, send to target and wait for response. If False, just store in memory.", ) - converter_ids: Optional[List[str]] = Field( + converter_ids: Optional[list[str]] = Field( None, description="Converter instance IDs to apply (overrides attack-level)" ) diff --git a/pyrit/backend/models/common.py b/pyrit/backend/models/common.py index 44203ddcd0..0a2e00e6b5 100644 --- a/pyrit/backend/models/common.py +++ b/pyrit/backend/models/common.py @@ -7,7 +7,7 @@ Includes pagination, error handling (RFC 7807), and shared base models. """ -from typing import Any, List, Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -42,7 +42,7 @@ class ProblemDetail(BaseModel): status: int = Field(..., description="HTTP status code") detail: str = Field(..., description="Human-readable explanation") instance: Optional[str] = Field(None, description="URI of the specific occurrence") - errors: Optional[List[FieldError]] = Field(None, description="Field-level errors for validation") + errors: Optional[list[FieldError]] = Field(None, description="Field-level errors for validation") # Sensitive field patterns to filter from identifiers diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index 27304b3932..1c4e615f87 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -7,7 +7,7 @@ This module defines the Instance models and preview functionality. """ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -35,16 +35,16 @@ class ConverterInstance(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") converter_type: str = Field(..., description="Converter class name (e.g., 'Base64Converter')") display_name: Optional[str] = Field(None, description="Human-readable display name") - supported_input_types: List[str] = Field( + supported_input_types: list[str] = Field( default_factory=list, description="Input data types supported by this converter" ) - supported_output_types: List[str] = Field( + supported_output_types: list[str] = Field( default_factory=list, description="Output data types produced by this converter" ) - converter_specific_params: Optional[Dict[str, Any]] = Field( + converter_specific_params: Optional[dict[str, Any]] = Field( None, description="Additional converter-specific parameters" ) - sub_converter_ids: Optional[List[str]] = Field( + sub_converter_ids: Optional[list[str]] = Field( None, description="Converter IDs of sub-converters (for pipelines/composites)" ) @@ -52,7 +52,7 @@ class ConverterInstance(BaseModel): class ConverterInstanceListResponse(BaseModel): """Response for listing converter instances.""" - items: List[ConverterInstance] = Field(..., description="List of converter instances") + items: list[ConverterInstance] = Field(..., description="List of converter instances") class CreateConverterRequest(BaseModel): @@ -60,7 +60,7 @@ class CreateConverterRequest(BaseModel): type: str = Field(..., description="Converter type (e.g., 'Base64Converter')") display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field( + params: dict[str, Any] = Field( default_factory=dict, description="Converter constructor parameters", ) @@ -95,7 +95,7 @@ class ConverterPreviewRequest(BaseModel): original_value: str = Field(..., description="Text to convert") original_value_data_type: PromptDataType = Field(default="text", description="Data type of original value") - converter_ids: List[str] = Field(..., description="Converter instance IDs to apply") + converter_ids: list[str] = Field(..., description="Converter instance IDs to apply") class ConverterPreviewResponse(BaseModel): @@ -105,4 +105,4 @@ class ConverterPreviewResponse(BaseModel): original_value_data_type: PromptDataType = Field(..., description="Data type of original value") converted_value: str = Field(..., description="Final converted text") converted_value_data_type: PromptDataType = Field(..., description="Data type of converted value") - steps: List[PreviewStep] = Field(..., description="Step-by-step conversion results") + steps: list[PreviewStep] = Field(..., description="Step-by-step conversion results") diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index 6bf46e2407..36b5634680 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -11,7 +11,7 @@ This module defines the Instance models for runtime target management. """ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -35,13 +35,13 @@ class TargetInstance(BaseModel): temperature: Optional[float] = Field(None, description="Temperature parameter for generation") top_p: Optional[float] = Field(None, description="Top-p parameter for generation") max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") - target_specific_params: Optional[Dict[str, Any]] = Field(None, description="Additional target-specific parameters") + target_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional target-specific parameters") class TargetListResponse(BaseModel): """Response for listing target instances.""" - items: List[TargetInstance] = Field(..., description="List of target instances") + items: list[TargetInstance] = Field(..., description="List of target instances") pagination: PaginationInfo = Field(..., description="Pagination metadata") @@ -49,4 +49,4 @@ class CreateTargetRequest(BaseModel): """Request to create a new target instance.""" type: str = Field(..., description="Target type (e.g., 'OpenAIChatTarget')") - params: Dict[str, Any] = Field(default_factory=dict, description="Target constructor parameters") + params: dict[str, Any] = Field(default_factory=dict, description="Target constructor parameters") diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 6b9851f09a..b32fb30978 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -8,7 +8,7 @@ This is the attack-centric API design. """ -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from fastapi import APIRouter, HTTPException, Query, status @@ -30,7 +30,7 @@ router = APIRouter(prefix="/attacks", tags=["attacks"]) -def _parse_labels(label_params: Optional[List[str]]) -> Optional[Dict[str, str]]: +def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str]]: """ Parse label query params in 'key:value' format to a dict. @@ -53,12 +53,12 @@ def _parse_labels(label_params: Optional[List[str]]) -> Optional[Dict[str, str]] ) async def list_attacks( attack_class: Optional[str] = Query(None, description="Filter by exact attack class name"), - converter_classes: Optional[List[str]] = Query( + converter_classes: Optional[list[str]] = Query( None, description="Filter by converter class names (repeatable, AND logic). Pass empty to match no-converter attacks.", ), outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), - label: Optional[List[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), + label: Optional[list[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), min_turns: Optional[int] = Query(None, ge=0, description="Filter by minimum executed turns"), max_turns: Optional[int] = Query(None, ge=0, description="Filter by maximum executed turns"), limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), diff --git a/pyrit/backend/routes/labels.py b/pyrit/backend/routes/labels.py index e2f2d64c40..60b7a635e0 100644 --- a/pyrit/backend/routes/labels.py +++ b/pyrit/backend/routes/labels.py @@ -7,7 +7,7 @@ Provides access to unique label values for filtering in the GUI. """ -from typing import Dict, List, Literal +from typing import Literal from fastapi import APIRouter, Query from pydantic import BaseModel, Field @@ -21,7 +21,7 @@ class LabelOptionsResponse(BaseModel): """Response containing unique label keys and their values.""" source: str = Field(..., description="Source type (e.g., 'attacks')") - labels: Dict[str, List[str]] = Field(..., description="Map of label keys to their unique values") + labels: dict[str, list[str]] = Field(..., description="Map of label keys to their unique values") @router.get( diff --git a/pyrit/backend/routes/version.py b/pyrit/backend/routes/version.py index 5bff355413..e24654b9d1 100644 --- a/pyrit/backend/routes/version.py +++ b/pyrit/backend/routes/version.py @@ -49,7 +49,7 @@ async def get_version_async() -> VersionResponse: build_info_path = Path("/app/build_info.json") if build_info_path.exists(): try: - with open(build_info_path, "r") as f: + with open(build_info_path) as f: build_info = json.load(f) source = build_info.get("source") commit = build_info.get("commit") diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 591ee649c1..04ceef4ef5 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -18,7 +18,7 @@ import uuid from datetime import datetime, timezone from functools import lru_cache -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from pyrit.backend.mappers.attack_mappers import ( attack_result_to_summary, @@ -64,9 +64,9 @@ async def list_attacks_async( self, *, attack_class: Optional[str] = None, - converter_classes: Optional[List[str]] = None, + converter_classes: Optional[list[str]] = None, outcome: Optional[Literal["undetermined", "success", "failure"]] = None, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, min_turns: Optional[int] = None, max_turns: Optional[int] = None, limit: int = 20, @@ -100,7 +100,7 @@ async def list_attacks_async( converter_classes=converter_classes, ) - filtered: List[AttackResult] = [] + filtered: list[AttackResult] = [] for ar in attack_results: if min_turns is not None and ar.executed_turns < min_turns: continue @@ -119,7 +119,7 @@ async def list_attacks_async( next_cursor = page_results[-1].conversation_id if has_more and page_results else None # Phase 2: Fetch pieces only for the page we're returning - page: List[AttackSummary] = [] + page: list[AttackSummary] = [] for ar in page_results: pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) page.append(attack_result_to_summary(ar, pieces=pieces)) @@ -129,7 +129,7 @@ async def list_attacks_async( pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_attack_options_async(self) -> List[str]: + async def get_attack_options_async(self) -> list[str]: """ Get all unique attack class names from stored attack results. @@ -141,7 +141,7 @@ async def get_attack_options_async(self) -> List[str]: """ return self._memory.get_unique_attack_class_names() - async def get_converter_options_async(self) -> List[str]: + async def get_converter_options_async(self) -> list[str]: """ Get all unique converter class names used across attack results. @@ -340,8 +340,8 @@ async def add_message_async(self, *, conversation_id: str, request: AddMessageRe # ======================================================================== def _paginate_attack_results( - self, items: List[AttackResult], cursor: Optional[str], limit: int - ) -> tuple[List[AttackResult], bool]: + self, items: list[AttackResult], cursor: Optional[str], limit: int + ) -> tuple[list[AttackResult], bool]: """ Apply cursor-based pagination over AttackResult objects. @@ -369,8 +369,8 @@ def _paginate_attack_results( async def _store_prepended_messages( self, conversation_id: str, - prepended: List[Any], - labels: Optional[Dict[str, str]] = None, + prepended: list[Any], + labels: Optional[dict[str, str]] = None, ) -> None: """Store prepended conversation messages in memory.""" seq = 0 @@ -393,7 +393,7 @@ async def _send_and_store_message( request: AddMessageRequest, sequence: int, *, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> None: """Send message to target via normalizer and store response.""" target_obj = get_target_service().get_target_object(target_unique_name=target_unique_name) @@ -424,7 +424,7 @@ async def _store_message_only( request: AddMessageRequest, sequence: int, *, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> None: """Store message without sending (send=False).""" for p in request.pieces: @@ -437,7 +437,7 @@ async def _store_message_only( ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) - def _get_converter_configs(self, request: AddMessageRequest) -> List[PromptConverterConfiguration]: + def _get_converter_configs(self, request: AddMessageRequest) -> list[PromptConverterConfiguration]: """ Get converter configurations if needed. diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index c18d0ec084..a0579239cc 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -14,7 +14,7 @@ import uuid from functools import lru_cache -from typing import Any, List, Optional, Tuple +from typing import Any, Optional from pyrit import prompt_converter from pyrit.backend.mappers.converter_mappers import converter_object_to_instance @@ -164,7 +164,7 @@ async def preview_conversion_async(self, *, request: ConverterPreviewRequest) -> steps=steps, ) - def get_converter_objects_for_ids(self, *, converter_ids: List[str]) -> List[Any]: + def get_converter_objects_for_ids(self, *, converter_ids: list[str]) -> list[Any]: """ Get converter objects for a list of IDs. @@ -226,14 +226,14 @@ def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any] resolved["converter"] = conv_obj return resolved - def _gather_converters(self, *, converter_ids: List[str]) -> List[Tuple[str, str, Any]]: + def _gather_converters(self, *, converter_ids: list[str]) -> list[tuple[str, str, Any]]: """ Gather converters to apply from IDs. Returns: List of tuples (converter_id, converter_type, converter_obj). """ - converters: List[Tuple[str, str, Any]] = [] + converters: list[tuple[str, str, Any]] = [] for conv_id in converter_ids: conv_obj = self.get_converter_object(converter_id=conv_id) if conv_obj is None: @@ -245,10 +245,10 @@ def _gather_converters(self, *, converter_ids: List[str]) -> List[Tuple[str, str async def _apply_converters( self, *, - converters: List[Tuple[str, str, Any]], + converters: list[tuple[str, str, Any]], initial_value: str, initial_type: PromptDataType, - ) -> Tuple[List[PreviewStep], str, PromptDataType]: + ) -> tuple[list[PreviewStep], str, PromptDataType]: """ Apply converters and collect steps. @@ -257,7 +257,7 @@ async def _apply_converters( """ current_value = initial_value current_type = initial_type - steps: List[PreviewStep] = [] + steps: list[PreviewStep] = [] for conv_id, conv_type, conv_obj in converters: input_value, input_type = current_value, current_type diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 159c7e7ce2..84d440de15 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -13,7 +13,7 @@ """ from functools import lru_cache -from typing import Any, List, Optional +from typing import Any, Optional from pyrit import prompt_target from pyrit.backend.mappers.target_mappers import target_object_to_instance @@ -119,7 +119,7 @@ async def list_targets_async( ) @staticmethod - def _paginate(items: List[TargetInstance], cursor: Optional[str], limit: int) -> tuple[List[TargetInstance], bool]: + def _paginate(items: list[TargetInstance], cursor: Optional[str], limit: int) -> tuple[list[TargetInstance], bool]: """ Apply cursor-based pagination. diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 717e23a624..809c3450ca 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -18,8 +18,9 @@ import json import logging import sys +from collections.abc import Callable, Sequence from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional from pyrit.setup import ConfigurationLoader from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP @@ -166,7 +167,7 @@ async def initialize_async(self) -> None: self._initialized = True @property - def scenario_registry(self) -> "ScenarioRegistry": + def scenario_registry(self) -> ScenarioRegistry: """ Get the scenario registry. Must call await initialize_async() first. @@ -181,7 +182,7 @@ def scenario_registry(self) -> "ScenarioRegistry": return self._scenario_registry @property - def initializer_registry(self) -> "InitializerRegistry": + def initializer_registry(self) -> InitializerRegistry: """ Get the initializer registry. Must call await initialize_async() first. @@ -213,7 +214,7 @@ async def list_scenarios_async(*, context: FrontendCore) -> list[ScenarioMetadat async def list_initializers_async( *, context: FrontendCore, discovery_path: Optional[Path] = None -) -> "Sequence[InitializerMetadata]": +) -> Sequence[InitializerMetadata]: """ List metadata for all available initializers. @@ -246,7 +247,7 @@ async def run_scenario_async( dataset_names: Optional[list[str]] = None, max_dataset_size: Optional[int] = None, print_summary: bool = True, -) -> "ScenarioResult": +) -> ScenarioResult: """ Run a scenario by name. @@ -457,7 +458,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: print(" Default Datasets: None") -def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") -> None: +def format_initializer_metadata(*, initializer_metadata: InitializerMetadata) -> None: """ Print formatted information about an initializer class. diff --git a/pyrit/common/apply_defaults.py b/pyrit/common/apply_defaults.py index 86bf2d167c..f3770f13bd 100644 --- a/pyrit/common/apply_defaults.py +++ b/pyrit/common/apply_defaults.py @@ -12,8 +12,9 @@ import inspect import logging import sys +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, Type, TypeVar +from typing import Any, TypeVar logger = logging.getLogger(__name__) @@ -57,7 +58,7 @@ class DefaultValueScope: be inherited by subclasses. """ - class_type: Type[object] + class_type: type[object] parameter_name: str include_subclasses: bool = True @@ -81,12 +82,12 @@ class GlobalDefaultValues: def __init__(self) -> None: """Initialize the global default values registry.""" - self._default_values: Dict[DefaultValueScope, Any] = {} + self._default_values: dict[DefaultValueScope, Any] = {} def set_default_value( self, *, - class_type: Type[object], + class_type: type[object], parameter_name: str, value: Any, include_subclasses: bool = True, @@ -111,7 +112,7 @@ def set_default_value( def get_default_value( self, *, - class_type: Type[object], + class_type: type[object], parameter_name: str, ) -> tuple[bool, Any]: """ @@ -150,7 +151,7 @@ def reset_defaults(self) -> None: logger.debug("Reset all default values") @property - def all_defaults(self) -> Dict[DefaultValueScope, Any]: + def all_defaults(self) -> dict[DefaultValueScope, Any]: """Get a copy of all current default values.""" return self._default_values.copy() @@ -171,7 +172,7 @@ def get_global_default_values() -> GlobalDefaultValues: def set_default_value( *, - class_type: Type[object], + class_type: type[object], parameter_name: str, value: Any, include_subclasses: bool = True, diff --git a/pyrit/common/csv_helper.py b/pyrit/common/csv_helper.py index b2bd8bfad4..48a9b9dd7e 100644 --- a/pyrit/common/csv_helper.py +++ b/pyrit/common/csv_helper.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. import csv -from typing import IO, Any, Dict, List +from typing import IO, Any -def read_csv(file: IO[Any]) -> List[Dict[str, str]]: +def read_csv(file: IO[Any]) -> list[dict[str, str]]: """ Read a CSV file and return its rows as dictionaries. @@ -16,7 +16,7 @@ def read_csv(file: IO[Any]) -> List[Dict[str, str]]: return list(reader) -def write_csv(file: IO[Any], examples: List[Dict[str, str]]) -> None: +def write_csv(file: IO[Any], examples: list[dict[str, str]]) -> None: """ Write a list of dictionaries to a CSV file. diff --git a/pyrit/common/deprecation.py b/pyrit/common/deprecation.py index 7f9cf65f0c..b730f934fa 100644 --- a/pyrit/common/deprecation.py +++ b/pyrit/common/deprecation.py @@ -4,7 +4,8 @@ from __future__ import annotations import warnings -from typing import Any, Callable +from collections.abc import Callable +from typing import Any def print_deprecation_message( diff --git a/pyrit/common/json_helper.py b/pyrit/common/json_helper.py index 7e9a4adc96..668b4335fb 100644 --- a/pyrit/common/json_helper.py +++ b/pyrit/common/json_helper.py @@ -2,20 +2,20 @@ # Licensed under the MIT license. import json -from typing import IO, Any, Dict, List, cast +from typing import IO, Any, cast -def read_json(file: IO[Any]) -> List[Dict[str, str]]: +def read_json(file: IO[Any]) -> list[dict[str, str]]: """ Read a JSON file and return its content. Returns: List[Dict[str, str]]: Parsed JSON content. """ - return cast(List[Dict[str, str]], json.load(file)) + return cast(list[dict[str, str]], json.load(file)) -def write_json(file: IO[Any], examples: List[Dict[str, str]]) -> None: +def write_json(file: IO[Any], examples: list[dict[str, str]]) -> None: """ Write a list of dictionaries to a JSON file. @@ -26,7 +26,7 @@ def write_json(file: IO[Any], examples: List[Dict[str, str]]) -> None: json.dump(examples, file) -def read_jsonl(file: IO[Any]) -> List[Dict[str, str]]: +def read_jsonl(file: IO[Any]) -> list[dict[str, str]]: """ Read a JSONL file and return its content. @@ -36,7 +36,7 @@ def read_jsonl(file: IO[Any]) -> List[Dict[str, str]]: return [json.loads(line) for line in file] -def write_jsonl(file: IO[Any], examples: List[Dict[str, str]]) -> None: +def write_jsonl(file: IO[Any], examples: list[dict[str, str]]) -> None: """ Write a list of dictionaries to a JSONL file. diff --git a/pyrit/common/singleton.py b/pyrit/common/singleton.py index 1c4239f2a2..858c469d9c 100644 --- a/pyrit/common/singleton.py +++ b/pyrit/common/singleton.py @@ -20,5 +20,5 @@ def __call__(cls, *args: object, **kwargs: object) -> object: The singleton instance if it exists, otherwise creates a new one and returns it. """ if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] diff --git a/pyrit/common/text_helper.py b/pyrit/common/text_helper.py index d8d4391e8f..4dc49a01ba 100644 --- a/pyrit/common/text_helper.py +++ b/pyrit/common/text_helper.py @@ -1,10 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import IO, Any, Dict, List +from typing import IO, Any -def read_txt(file: IO[Any]) -> List[Dict[str, str]]: +def read_txt(file: IO[Any]) -> list[dict[str, str]]: """ Read a TXT file and return its content. @@ -14,7 +14,7 @@ def read_txt(file: IO[Any]) -> List[Dict[str, str]]: return [{"prompt": line.strip()} for line in file.readlines()] -def write_txt(file: IO[Any], examples: List[Dict[str, str]]) -> None: +def write_txt(file: IO[Any], examples: list[dict[str, str]]) -> None: """ Write a list of dictionaries to a TXT file. diff --git a/pyrit/common/tool_configs.py b/pyrit/common/tool_configs.py index aa2a0a2c60..ab56b5a95a 100644 --- a/pyrit/common/tool_configs.py +++ b/pyrit/common/tool_configs.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from enum import Enum -from typing import Any, Dict +from typing import Any class OpenAIToolType(str, Enum): @@ -13,16 +13,16 @@ class OpenAIToolType(str, Enum): FILE_SEARCH = "file_search" -def web_search_tool() -> Dict[str, Any]: +def web_search_tool() -> dict[str, Any]: """Return the configuration for OpenAI's web search tool.""" return {"type": OpenAIToolType.WEB_SEARCH_PREVIEW.value} -def code_interpreter_tool() -> Dict[str, Any]: +def code_interpreter_tool() -> dict[str, Any]: """Return the configuration for OpenAI's code interpreter tool.""" return {"type": OpenAIToolType.CODE_INTERPRETER.value} -def file_search_tool() -> Dict[str, Any]: +def file_search_tool() -> dict[str, Any]: """Return the configuration for OpenAI's file search tool.""" return {"type": OpenAIToolType.FILE_SEARCH.value} diff --git a/pyrit/common/utils.py b/pyrit/common/utils.py index 401361eb79..fe2362ce66 100644 --- a/pyrit/common/utils.py +++ b/pyrit/common/utils.py @@ -8,7 +8,7 @@ import math import random from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def combine_dict( return result -def combine_list(list1: Union[str, List[str]], list2: Union[str, List[str]]) -> list[str]: +def combine_list(list1: Union[str, list[str]], list2: Union[str, list[str]]) -> list[str]: """ Combine two lists or strings into a single list with unique values. @@ -78,7 +78,7 @@ def combine_list(list1: Union[str, List[str]], list2: Union[str, List[str]]) -> return list(set(list1 + list2)) -def get_random_indices(*, start: int, size: int, proportion: float) -> List[int]: +def get_random_indices(*, start: int, size: int, proportion: float) -> list[int]: """ Generate a list of random indices based on the specified proportion of a given size. The indices are selected from the range [start, start + size). @@ -126,7 +126,7 @@ def to_sha256(data: str) -> str: def warn_if_set( - *, config: Any, unused_fields: List[str], log: Union[logging.Logger, logging.LoggerAdapter[logging.Logger]] = logger + *, config: Any, unused_fields: list[str], log: Union[logging.Logger, logging.LoggerAdapter[logging.Logger]] = logger ) -> None: """ Warn about unused parameters in configurations. @@ -168,9 +168,9 @@ def warn_if_set( def get_kwarg_param( *, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], param_name: str, - expected_type: Type[_T], + expected_type: type[_T], required: bool = True, default_value: Optional[_T] = None, ) -> Optional[_T]: diff --git a/pyrit/common/yaml_loadable.py b/pyrit/common/yaml_loadable.py index 7b03a7fc03..c7f4efa739 100644 --- a/pyrit/common/yaml_loadable.py +++ b/pyrit/common/yaml_loadable.py @@ -3,7 +3,7 @@ import abc from pathlib import Path -from typing import Type, TypeVar, Union +from typing import TypeVar, Union import yaml @@ -18,7 +18,7 @@ class YamlLoadable(abc.ABC): """ @classmethod - def from_yaml_file(cls: Type[T], file: Union[Path | str]) -> T: + def from_yaml_file(cls: type[T], file: Union[Path | str]) -> T: """ Create a new object from a YAML file. diff --git a/pyrit/datasets/jailbreak/text_jailbreak.py b/pyrit/datasets/jailbreak/text_jailbreak.py index 9bad3f1a21..94e317d368 100644 --- a/pyrit/datasets/jailbreak/text_jailbreak.py +++ b/pyrit/datasets/jailbreak/text_jailbreak.py @@ -5,7 +5,7 @@ import random import threading from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional from pyrit.common.path import JAILBREAK_TEMPLATES_PATH from pyrit.models import SeedPrompt @@ -18,11 +18,11 @@ class TextJailBreak: A class that manages jailbreak datasets (like DAN, etc.). """ - _template_cache: Optional[Dict[str, List[Path]]] = None + _template_cache: Optional[dict[str, list[Path]]] = None _cache_lock: threading.Lock = threading.Lock() @classmethod - def _scan_template_files(cls) -> Dict[str, List[Path]]: + def _scan_template_files(cls) -> dict[str, list[Path]]: """ Scan the jailbreak templates directory for YAML files. @@ -32,14 +32,14 @@ def _scan_template_files(cls) -> Dict[str, List[Path]]: Returns: Dict[str, List[Path]]: Mapping of filename to list of matching paths. """ - result: Dict[str, List[Path]] = {} + result: dict[str, list[Path]] = {} for path in JAILBREAK_TEMPLATES_PATH.rglob("*.yaml"): if "multi_parameter" not in path.parts: result.setdefault(path.name, []).append(path) return result @classmethod - def _get_template_cache(cls) -> Dict[str, List[Path]]: + def _get_template_cache(cls) -> dict[str, list[Path]]: """ Return the cached filename-to-path lookup, building it on first access. @@ -80,7 +80,7 @@ def _resolve_template_by_name(cls, template_file_name: str) -> Path: return paths[0] @classmethod - def _get_all_template_paths(cls) -> List[Path]: + def _get_all_template_paths(cls) -> list[Path]: """ Return a flat list of all cached template file paths. @@ -173,7 +173,7 @@ def _load_random_template(self) -> None: raise ValueError("No jailbreak template with a single 'prompt' parameter found among available templates.") - def _validate_required_kwargs(self, kwargs: Dict[str, Any]) -> None: + def _validate_required_kwargs(self, kwargs: dict[str, Any]) -> None: """ Verify that all template parameters (except 'prompt') are present in kwargs. @@ -195,7 +195,7 @@ def _validate_required_kwargs(self, kwargs: Dict[str, Any]) -> None: f"Required parameters (excluding 'prompt'): {required_params}" ) - def _apply_extra_kwargs(self, kwargs: Dict[str, Any]) -> None: + def _apply_extra_kwargs(self, kwargs: dict[str, Any]) -> None: """ Apply additional keyword arguments to the template, preserving the prompt placeholder. @@ -208,7 +208,7 @@ def _apply_extra_kwargs(self, kwargs: Dict[str, Any]) -> None: self.template.value = self.template.render_template_value_silent(**kwargs) @classmethod - def get_jailbreak_templates(cls, num_templates: Optional[int] = None) -> List[str]: + def get_jailbreak_templates(cls, num_templates: Optional[int] = None) -> list[str]: """ Retrieve all jailbreaks from the JAILBREAK_TEMPLATES_PATH. diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 193cfc08d0..51d3790fc9 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. import logging +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider from pyrit.models import SeedDataset diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index b5aa62ac22..287f42f8bf 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import List, Literal, Optional +from typing import Literal, Optional from datasets import load_dataset @@ -64,7 +64,7 @@ def __init__( self, *, harm_categories: Optional[ - List[ + list[ Literal[ "Controlled/Regulated Substances", "Copyright/Trademark/Plagiarism", diff --git a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py index 160f208fc6..905a973177 100644 --- a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py @@ -3,7 +3,7 @@ import ast import logging -from typing import List, Literal, Optional +from typing import Literal, Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -46,7 +46,7 @@ def __init__( "English", "Hindi", "French", "Spanish", "Arabic", "Russian", "Serbian", "Tagalog" ] = "English", harm_categories: Optional[ - List[ + list[ Literal[ "Bullying & Harassment", "Discrimination & Injustice", diff --git a/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py b/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py index e52fcebb5b..2920c8fffb 100644 --- a/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import logging -from typing import Literal, Sequence +from collections.abc import Sequence +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index 223d10066f..80943ca15e 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import List, Literal, Optional +from typing import Literal, Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( @@ -49,7 +49,7 @@ def __init__( "harmbench_behaviors_multimodal_all.csv" ), source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[List[SemanticCategory]] = None, + categories: Optional[list[SemanticCategory]] = None, ): """ Initialize the HarmBench multimodal dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py index 4bbdb5f5c1..88f78e1129 100644 --- a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import List, Literal, Optional +from typing import Literal, Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -29,7 +29,7 @@ def __init__( source: str = "PKU-Alignment/PKU-SafeRLHF", include_safe_prompts: bool = True, filter_harm_categories: Optional[ - List[ + list[ Literal[ "Animal Abuse", "Copyright Issues", diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 216c6d08e7..ff7babb9bd 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -7,8 +7,9 @@ import logging import tempfile from abc import ABC +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, TextIO, cast +from typing import Any, Literal, Optional, TextIO, cast import requests from datasets import DownloadMode, disable_progress_bars, load_dataset @@ -22,10 +23,10 @@ logger = logging.getLogger(__name__) # Define the type for the file handlers -FileHandlerRead = Callable[[TextIO], List[Dict[str, str]]] -FileHandlerWrite = Callable[[TextIO, List[Dict[str, str]]], None] +FileHandlerRead = Callable[[TextIO], list[dict[str, str]]] +FileHandlerWrite = Callable[[TextIO, list[dict[str, str]]], None] -FILE_TYPE_HANDLERS: Dict[str, Dict[str, Callable[..., Any]]] = { +FILE_TYPE_HANDLERS: dict[str, dict[str, Callable[..., Any]]] = { "json": {"read": read_json, "write": write_json}, "jsonl": {"read": read_jsonl, "write": write_jsonl}, "csv": {"read": read_csv, "write": write_csv}, @@ -75,7 +76,7 @@ def _validate_file_type(self, file_type: str) -> None: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") - def _read_cache(self, *, cache_file: Path, file_type: str) -> List[Dict[str, str]]: + def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ Read data from cache. @@ -91,9 +92,9 @@ def _read_cache(self, *, cache_file: Path, file_type: str) -> List[Dict[str, str """ self._validate_file_type(file_type) with cache_file.open("r", encoding="utf-8") as file: - return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) + return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) - def _write_cache(self, *, cache_file: Path, examples: List[Dict[str, str]], file_type: str) -> None: + def _write_cache(self, *, cache_file: Path, examples: list[dict[str, str]], file_type: str) -> None: """ Write data to cache. @@ -110,7 +111,7 @@ def _write_cache(self, *, cache_file: Path, examples: List[Dict[str, str]], file with cache_file.open("w", encoding="utf-8") as file: FILE_TYPE_HANDLERS[file_type]["write"](file, examples) - def _fetch_from_public_url(self, *, source: str, file_type: str) -> List[Dict[str, str]]: + def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ Fetch examples from a public URL. @@ -129,16 +130,16 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> List[Dict[st if response.status_code == 200: if file_type in FILE_TYPE_HANDLERS: if file_type == "json": - return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) + return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) return cast( - List[Dict[str, str]], + list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") - def _fetch_from_file(self, *, source: str, file_type: str) -> List[Dict[str, str]]: + def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ Fetch examples from a local file. @@ -152,9 +153,9 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> List[Dict[str, str Raises: ValueError: If the file_type is invalid. """ - with open(source, "r", encoding="utf-8") as file: + with open(source, encoding="utf-8") as file: if file_type in FILE_TYPE_HANDLERS: - return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) + return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") @@ -164,7 +165,7 @@ def _fetch_from_url( source: str, source_type: Literal["public_url", "file"] = "public_url", cache: bool = True, - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: """ Fetch examples from a specified source with caching support. diff --git a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py index 1f6ce6e7a5..275baf4aed 100644 --- a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py @@ -3,7 +3,7 @@ import logging import os -from typing import List, Optional +from typing import Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -98,7 +98,7 @@ def __init__( self, *, source: str = "sorry-bench/sorry-bench-202503", - categories: Optional[List[str]] = None, + categories: Optional[list[str]] = None, prompt_style: Optional[str] = None, token: Optional[str] = None, ): diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 8a03f5bd68..25f8291ee9 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import List, Literal, Optional +from typing import Literal, Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( @@ -55,8 +55,8 @@ def __init__( *, source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv", source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[List[VLSUCategory]] = None, - unsafe_grades: Optional[List[str]] = ["unsafe", "borderline"], + categories: Optional[list[VLSUCategory]] = None, + unsafe_grades: Optional[list[str]] = ["unsafe", "borderline"], max_examples: Optional[int] = None, ): """ diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 376ee652f0..1a48cc2afb 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -5,7 +5,7 @@ import inspect import logging from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional from tqdm import tqdm @@ -27,7 +27,7 @@ class SeedDatasetProvider(ABC): - dataset_name property: Human-readable name for the dataset """ - _registry: Dict[str, Type["SeedDatasetProvider"]] = {} + _registry: dict[str, type["SeedDatasetProvider"]] = {} def __init_subclass__(cls, **kwargs: Any) -> None: """ @@ -70,7 +70,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: pass @classmethod - def get_all_providers(cls) -> Dict[str, Type["SeedDatasetProvider"]]: + def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: """ Get all registered dataset provider classes. @@ -80,7 +80,7 @@ def get_all_providers(cls) -> Dict[str, Type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> List[str]: + def get_all_dataset_names(cls) -> list[str]: """ Get the names of all registered datasets. @@ -108,7 +108,7 @@ def get_all_dataset_names(cls) -> List[str]: async def fetch_datasets_async( cls, *, - dataset_names: Optional[List[str]] = None, + dataset_names: Optional[list[str]] = None, cache: bool = True, max_concurrency: int = 5, ) -> list[SeedDataset]: @@ -149,8 +149,8 @@ async def fetch_datasets_async( raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( - provider_name: str, provider_class: Type["SeedDatasetProvider"] - ) -> Optional[Tuple[str, SeedDataset]]: + provider_name: str, provider_class: type["SeedDatasetProvider"] + ) -> Optional[tuple[str, SeedDataset]]: """ Fetch a single dataset with error handling. @@ -176,8 +176,8 @@ async def fetch_single_dataset( pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( - provider_name: str, provider_class: Type["SeedDatasetProvider"] - ) -> Optional[Tuple[str, SeedDataset]]: + provider_name: str, provider_class: type["SeedDatasetProvider"] + ) -> Optional[tuple[str, SeedDataset]]: """ Enforce concurrency limit and update progress during dataset fetch. @@ -199,7 +199,7 @@ async def fetch_with_semaphore( pbar.close() # Merge datasets with the same name - datasets: Dict[str, SeedDataset] = {} + datasets: dict[str, SeedDataset] = {} for result in results: # Skip None results (filtered datasets) if result is None: diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index 5640a6a27b..f6b51a10b8 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import asyncio -from typing import Any, Awaitable, Callable, Optional +from collections.abc import Awaitable, Callable +from typing import Any, Optional import tenacity from openai import AsyncOpenAI diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index a452e63f58..eeb9e16535 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -5,7 +5,8 @@ import logging import os from abc import ABC -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from openai import RateLimitError from tenacity import ( diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 8062c47eaf..5f396ceb77 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -148,4 +148,4 @@ def remove_markdown_json(response_msg: str) -> str: json.loads(response_msg) return response_msg except json.JSONDecodeError: - return "Invalid JSON response: {}".format(response_msg) + return f"Invalid JSON response: {response_msg}" diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 589fc98503..7a27cb5666 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -3,8 +3,9 @@ import logging import uuid +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional from pyrit.common.utils import combine_dict from pyrit.executor.attack.component.prepended_conversation_config import ( @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) -def mark_messages_as_simulated(messages: Sequence[Message]) -> List[Message]: +def mark_messages_as_simulated(messages: Sequence[Message]) -> list[Message]: """ Mark assistant messages as simulated_assistant for traceability. @@ -51,13 +52,13 @@ def mark_messages_as_simulated(messages: Sequence[Message]) -> List[Message]: def get_adversarial_chat_messages( - prepended_conversation: List[Message], + prepended_conversation: list[Message], *, adversarial_chat_conversation_id: str, attack_identifier: ComponentIdentifier, adversarial_chat_target_identifier: ComponentIdentifier, - labels: Optional[Dict[str, str]] = None, -) -> List[Message]: + labels: Optional[dict[str, str]] = None, +) -> list[Message]: """ Transform prepended conversation messages for adversarial chat with swapped roles. @@ -82,13 +83,13 @@ def get_adversarial_chat_messages( if not prepended_conversation: return [] - role_swap: Dict[ChatMessageRole, ChatMessageRole] = { + role_swap: dict[ChatMessageRole, ChatMessageRole] = { "user": "assistant", "assistant": "user", "simulated_assistant": "user", } - result: List[Message] = [] + result: list[Message] = [] for message in prepended_conversation: for piece in message.message_pieces: @@ -118,7 +119,7 @@ def get_adversarial_chat_messages( return result -async def build_conversation_context_string_async(messages: List[Message]) -> str: +async def build_conversation_context_string_async(messages: list[Message]) -> str: """ Build a formatted context string from a list of messages. @@ -139,7 +140,7 @@ async def build_conversation_context_string_async(messages: List[Message]) -> st return await normalizer.normalize_string_async(messages) -def get_prepended_turn_count(prepended_conversation: Optional[List[Message]]) -> int: +def get_prepended_turn_count(prepended_conversation: Optional[list[Message]]) -> int: """ Count the number of turns (assistant responses) in a prepended conversation. @@ -166,7 +167,7 @@ class ConversationState: # Scores from the last assistant message (for attack-specific interpretation) # Used by Crescendo to detect refusals and objective achievement - last_assistant_message_scores: List[Score] = field(default_factory=list) + last_assistant_message_scores: list[Score] = field(default_factory=list) class ConversationManager: @@ -198,7 +199,7 @@ def __init__( self._memory = CentralMemory.get_memory_instance() self._attack_identifier = attack_identifier - def get_conversation(self, conversation_id: str) -> List[Message]: + def get_conversation(self, conversation_id: str) -> list[Message]: """ Retrieve a conversation by its ID. @@ -244,7 +245,7 @@ def set_system_prompt( target: PromptChatTarget, conversation_id: str, system_prompt: str, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> None: """ Set or update the system prompt for a conversation. @@ -268,10 +269,10 @@ async def initialize_context_async( context: "AttackContext[Any]", target: PromptTarget, conversation_id: str, - request_converters: Optional[List[PromptConverterConfiguration]] = None, + request_converters: Optional[list[PromptConverterConfiguration]] = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, max_turns: Optional[int] = None, - memory_labels: Optional[Dict[str, str]] = None, + memory_labels: Optional[dict[str, str]] = None, ) -> ConversationState: """ Initialize attack context with prepended conversation and merged labels. @@ -343,7 +344,7 @@ async def _handle_non_chat_target_async( self, *, context: "AttackContext[Any]", - prepended_conversation: List[Message], + prepended_conversation: list[Message], config: Optional["PrependedConversationConfig"], ) -> ConversationState: """ @@ -412,9 +413,9 @@ async def _handle_non_chat_target_async( async def add_prepended_conversation_to_memory_async( self, *, - prepended_conversation: List[Message], + prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[List[PromptConverterConfiguration]] = None, + request_converters: Optional[list[PromptConverterConfiguration]] = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, max_turns: Optional[int] = None, ) -> int: @@ -493,9 +494,9 @@ async def _process_prepended_for_chat_target_async( self, *, context: "AttackContext[Any]", - prepended_conversation: List[Message], + prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[List[PromptConverterConfiguration]], + request_converters: Optional[list[PromptConverterConfiguration]], prepended_conversation_config: Optional["PrependedConversationConfig"], max_turns: Optional[int], ) -> ConversationState: @@ -562,8 +563,8 @@ async def _apply_converters_async( self, *, message: Message, - request_converters: List[PromptConverterConfiguration], - apply_to_roles: Optional[List[ChatMessageRole]], + request_converters: list[PromptConverterConfiguration], + apply_to_roles: Optional[list[ChatMessageRole]], ) -> None: """ Apply converters to message pieces. diff --git a/pyrit/executor/attack/component/prepended_conversation_config.py b/pyrit/executor/attack/component/prepended_conversation_config.py index bb98f5c6cb..c78ffad767 100644 --- a/pyrit/executor/attack/component/prepended_conversation_config.py +++ b/pyrit/executor/attack/component/prepended_conversation_config.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Literal, Optional, get_args +from typing import Literal, Optional, get_args from pyrit.message_normalizer import ( ConversationContextNormalizer, @@ -28,7 +28,7 @@ class PrependedConversationConfig: # Roles for which request converters should be applied to prepended messages. # By default, converters are applied to all roles. # Example: ["user"] to apply converters only to user messages. - apply_converters_to_roles: List[ChatMessageRole] = field(default_factory=lambda: list(get_args(ChatMessageRole))) + apply_converters_to_roles: list[ChatMessageRole] = field(default_factory=lambda: list(get_args(ChatMessageRole))) # Optional normalizer to format conversation history into a single text block. # Must implement MessageStringNormalizer (e.g., TokenizerTemplateNormalizer or ConversationContextNormalizer). @@ -75,7 +75,7 @@ def for_non_chat_target( cls, *, message_normalizer: Optional[MessageStringNormalizer] = None, - apply_converters_to_roles: Optional[List[ChatMessageRole]] = None, + apply_converters_to_roles: Optional[list[ChatMessageRole]] = None, ) -> PrependedConversationConfig: """ Create a configuration for use with non-chat targets. diff --git a/pyrit/executor/attack/core/attack_config.py b/pyrit/executor/attack/core/attack_config.py index db322d24c6..7d128ffd79 100644 --- a/pyrit/executor/attack/core/attack_config.py +++ b/pyrit/executor/attack/core/attack_config.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union from pyrit.executor.core import StrategyConverterConfig from pyrit.models import SeedPrompt @@ -48,7 +48,7 @@ class AttackScoringConfig: refusal_scorer: Optional[TrueFalseScorer] = None # Additional scorers for auxiliary metrics or custom evaluations - auxiliary_scorers: List[Scorer] = field(default_factory=list) + auxiliary_scorers: list[Scorer] = field(default_factory=list) # Whether to use scoring results as feedback for iterative attacks use_score_as_feedback: bool = True diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index cad1a8b5a7..06fa2e6823 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -8,16 +8,13 @@ """ import asyncio +from collections.abc import Iterator, Sequence from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - Iterator, - List, Optional, - Sequence, TypeVar, ) @@ -57,8 +54,8 @@ class AttackExecutorResult(Generic[AttackResultT]): Note: "completed" means the execution finished, not that the attack objective was achieved. """ - completed_results: List[AttackResultT] - incomplete_objectives: List[tuple[str, BaseException]] + completed_results: list[AttackResultT] + incomplete_objectives: list[tuple[str, BaseException]] def __iter__(self) -> Iterator[AttackResultT]: """ @@ -93,7 +90,7 @@ def all_completed(self) -> bool: return len(self.incomplete_objectives) == 0 @property - def exceptions(self) -> List[BaseException]: + def exceptions(self) -> list[BaseException]: """Get all exceptions from incomplete objectives.""" return [exception for _, exception in self.incomplete_objectives] @@ -102,7 +99,7 @@ def raise_if_incomplete(self) -> None: if self.incomplete_objectives: raise self.incomplete_objectives[0][1] - def get_results(self) -> List[AttackResultT]: + def get_results(self) -> list[AttackResultT]: """ Get completed results, raising if any incomplete. @@ -143,7 +140,7 @@ async def execute_attack_from_seed_groups_async( seed_groups: Sequence[SeedAttackGroup], adversarial_chat: Optional["PromptChatTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - field_overrides: Optional[Sequence[Dict[str, Any]]] = None, + field_overrides: Optional[Sequence[dict[str, Any]]] = None, return_partial_on_failure: bool = False, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: @@ -215,7 +212,7 @@ async def execute_attack_async( *, attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], objectives: Sequence[str], - field_overrides: Optional[Sequence[Dict[str, Any]]] = None, + field_overrides: Optional[Sequence[dict[str, Any]]] = None, return_partial_on_failure: bool = False, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: @@ -252,7 +249,7 @@ async def execute_attack_async( params_type = attack.params_type # Build params list - params_list: List[AttackParameters] = [] + params_list: list[AttackParameters] = [] for i, objective in enumerate(objectives): # Start with broadcast fields fields = dict(broadcast_fields) @@ -315,7 +312,7 @@ def _process_execution_results( self, *, objectives: Sequence[str], - results_or_exceptions: List[Any], + results_or_exceptions: list[Any], return_partial_on_failure: bool, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -332,8 +329,8 @@ def _process_execution_results( Raises: BaseException: If return_partial_on_failure=False and any failed. """ - completed: List[AttackStrategyResultT] = [] - incomplete: List[tuple[str, BaseException]] = [] + completed: list[AttackStrategyResultT] = [] + incomplete: list[tuple[str, BaseException]] = [] for objective, result in zip(objectives, results_or_exceptions): if isinstance(result, BaseException): @@ -362,9 +359,9 @@ async def execute_multi_objective_attack_async( self, *, attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], - objectives: List[str], - prepended_conversation: Optional[List[Message]] = None, - memory_labels: Optional[Dict[str, str]] = None, + objectives: list[str], + prepended_conversation: Optional[list[Message]] = None, + memory_labels: Optional[dict[str, str]] = None, return_partial_on_failure: bool = False, **attack_params: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: @@ -392,7 +389,7 @@ async def execute_multi_objective_attack_async( ) # Build field_overrides if prepended_conversation is provided (broadcast to all) - field_overrides: Optional[List[Dict[str, Any]]] = None + field_overrides: Optional[list[dict[str, Any]]] = None if prepended_conversation: field_overrides = [{"prepended_conversation": prepended_conversation} for _ in objectives] @@ -409,10 +406,10 @@ async def execute_single_turn_attacks_async( self, *, attack: AttackStrategy["_SingleTurnContextT", AttackStrategyResultT], - objectives: List[str], - messages: Optional[List[Message]] = None, - prepended_conversations: Optional[List[List[Message]]] = None, - memory_labels: Optional[Dict[str, str]] = None, + objectives: list[str], + messages: Optional[list[Message]] = None, + prepended_conversations: Optional[list[list[Message]]] = None, + memory_labels: Optional[dict[str, str]] = None, return_partial_on_failure: bool = False, **attack_params: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: @@ -450,11 +447,11 @@ async def execute_single_turn_attacks_async( ) # Build field_overrides from per-objective parameters - field_overrides: Optional[List[Dict[str, Any]]] = None + field_overrides: Optional[list[dict[str, Any]]] = None if messages or prepended_conversations: field_overrides = [] for i in range(len(objectives)): - override: Dict[str, Any] = {} + override: dict[str, Any] = {} if messages and i < len(messages): override["next_message"] = messages[i] if prepended_conversations and i < len(prepended_conversations): @@ -474,10 +471,10 @@ async def execute_multi_turn_attacks_async( self, *, attack: AttackStrategy["_MultiTurnContextT", AttackStrategyResultT], - objectives: List[str], - messages: Optional[List[Message]] = None, - prepended_conversations: Optional[List[List[Message]]] = None, - memory_labels: Optional[Dict[str, str]] = None, + objectives: list[str], + messages: Optional[list[Message]] = None, + prepended_conversations: Optional[list[list[Message]]] = None, + memory_labels: Optional[dict[str, str]] = None, return_partial_on_failure: bool = False, **attack_params: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: @@ -515,11 +512,11 @@ async def execute_multi_turn_attacks_async( ) # Build field_overrides from per-objective parameters - field_overrides: Optional[List[Dict[str, Any]]] = None + field_overrides: Optional[list[dict[str, Any]]] = None if messages or prepended_conversations: field_overrides = [] for i in range(len(objectives)): - override: Dict[str, Any] = {} + override: dict[str, Any] = {} if messages and i < len(messages): override["next_message"] = messages[i] if prepended_conversations and i < len(prepended_conversations): diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 4b9b36b46e..95635cde3b 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -5,7 +5,7 @@ import dataclasses from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar from pyrit.models import Message, SeedAttackGroup, SeedGroup @@ -36,10 +36,10 @@ class AttackParameters: next_message: Optional[Message] = None # Conversation that is automatically prepended to the target model - prepended_conversation: Optional[List[Message]] = None + prepended_conversation: Optional[list[Message]] = None # Additional labels that can be applied to the prompts throughout the attack - memory_labels: Optional[Dict[str, str]] = field(default_factory=dict) + memory_labels: Optional[dict[str, str]] = field(default_factory=dict) def __str__(self) -> str: """Return a nicely formatted string representation of the attack parameters.""" @@ -75,11 +75,11 @@ def __str__(self) -> str: @classmethod async def from_seed_group_async( - cls: Type[AttackParamsT], + cls: type[AttackParamsT], *, seed_group: SeedAttackGroup, - adversarial_chat: Optional["PromptChatTarget"] = None, - objective_scorer: Optional["TrueFalseScorer"] = None, + adversarial_chat: Optional[PromptChatTarget] = None, + objective_scorer: Optional[TrueFalseScorer] = None, **overrides: Any, ) -> AttackParamsT: """ @@ -126,7 +126,7 @@ async def from_seed_group_async( assert seed_group.objective is not None # Build params dict, only including fields this class accepts - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if "objective" in valid_fields: params["objective"] = seed_group.objective.value @@ -181,7 +181,7 @@ async def from_seed_group_async( return cls(**params) @classmethod - def excluding(cls, *field_names: str) -> Type["AttackParameters"]: + def excluding(cls, *field_names: str) -> type[AttackParameters]: """ Create a new AttackParameters subclass that excludes the specified fields. @@ -208,7 +208,7 @@ def excluding(cls, *field_names: str) -> Type["AttackParameters"]: raise ValueError(f"Cannot exclude non-existent fields: {invalid}. Valid fields: {current_fields}") # Build new fields list excluding the specified ones - new_fields: List[Any] = [] + new_fields: list[Any] = [] for f in dataclasses.fields(cls): if f.name not in field_names: # Preserve field defaults diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 1fa0d8f321..a8a6ecec67 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -59,8 +59,8 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): # Mutable overrides for attacks that generate these values internally _next_message_override: Optional[Message] = None - _prepended_conversation_override: Optional[List[Message]] = None - _memory_labels_override: Optional[Dict[str, str]] = None + _prepended_conversation_override: Optional[list[Message]] = None + _memory_labels_override: Optional[dict[str, str]] = None # Convenience properties that delegate to params or overrides @property @@ -69,7 +69,7 @@ def objective(self) -> str: return self.params.objective @property - def memory_labels(self) -> Dict[str, str]: + def memory_labels(self) -> dict[str, str]: """Additional labels that can be applied to the prompts throughout the attack.""" # Check override first (for attacks that merge labels) if self._memory_labels_override is not None: @@ -77,12 +77,12 @@ def memory_labels(self) -> Dict[str, str]: return self.params.memory_labels or {} @memory_labels.setter - def memory_labels(self, value: Dict[str, str]) -> None: + def memory_labels(self, value: dict[str, str]) -> None: """Set the memory labels (for attacks that merge strategy + context labels).""" self._memory_labels_override = value @property - def prepended_conversation(self) -> List[Message]: + def prepended_conversation(self) -> list[Message]: """Conversation that is automatically prepended to the target model.""" # Check override first (for attacks that generate internally) if self._prepended_conversation_override is not None: @@ -93,7 +93,7 @@ def prepended_conversation(self) -> List[Message]: return [] @prepended_conversation.setter - def prepended_conversation(self, value: List[Message]) -> None: + def prepended_conversation(self, value: list[Message]) -> None: """Set the prepended conversation (for attacks that generate internally).""" self._prepended_conversation_override = value @@ -236,7 +236,7 @@ def __init__( *, objective_target: PromptTarget, context_type: type[AttackStrategyContextT], - params_type: Type[AttackParamsT] = AttackParameters, # type: ignore[assignment] + params_type: type[AttackParamsT] = AttackParameters, # type: ignore[assignment] logger: logging.Logger = logger, ): """ @@ -268,8 +268,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the attack strategy identifier. @@ -287,7 +287,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this attack strategy. """ - all_children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = { + all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { "objective_target": self.get_objective_target().get_identifier(), } @@ -326,7 +326,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() @property - def params_type(self) -> Type[AttackParameters]: + def params_type(self) -> type[AttackParameters]: """ Get the parameters type for this attack strategy. @@ -372,7 +372,7 @@ async def execute_async( *, objective: str, next_message: Optional[Message] = None, - prepended_conversation: Optional[List[Message]] = None, + prepended_conversation: Optional[list[Message]] = None, memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> AttackStrategyResultT: ... diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index feabb98215..6fce7441c2 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -5,7 +5,7 @@ import textwrap from dataclasses import dataclass, field from string import Formatter -from typing import Any, List, Optional +from typing import Any, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.exceptions import ComponentRole, execution_context @@ -43,7 +43,7 @@ class ChunkedRequestAttackContext(MultiTurnAttackContext[Any]): """Context for the ChunkedRequest attack strategy.""" # Collected chunk responses - chunk_responses: List[str] = field(default_factory=list) + chunk_responses: list[str] = field(default_factory=list) class ChunkedRequestAttack(MultiTurnAttackStrategy[ChunkedRequestAttackContext, AttackResult]): @@ -186,7 +186,7 @@ def _validate_context(self, *, context: ChunkedRequestAttackContext) -> None: if not context.objective or context.objective.isspace(): raise ValueError("Attack objective must be provided and non-empty in the context") - def _generate_chunk_prompts(self, context: ChunkedRequestAttackContext) -> List[str]: + def _generate_chunk_prompts(self, context: ChunkedRequestAttackContext) -> list[str]: """ Generate chunk request prompts based on the configured strategy. diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index b005e21218..b36dc2b28c 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -3,9 +3,10 @@ import json import logging +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Optional, Union, cast from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 873a2fa7f1..e27a5938fd 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Type +from typing import TYPE_CHECKING, Any, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import get_kwarg_param @@ -46,11 +46,11 @@ class MultiPromptSendingAttackParameters(AttackParameters): Only accepts objective and user_messages fields. """ - user_messages: Optional[List[Message]] = None + user_messages: Optional[list[Message]] = None @classmethod async def from_seed_group_async( - cls: Type["MultiPromptSendingAttackParameters"], + cls: type["MultiPromptSendingAttackParameters"], seed_group: SeedAttackGroup, *, adversarial_chat: Optional["PromptChatTarget"] = None, diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 6de7127969..2bcdeb391e 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, Type, TypeVar +from typing import Any, Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -71,7 +71,7 @@ def __init__( *, objective_target: PromptTarget, context_type: type[MultiTurnAttackStrategyContextT], - params_type: Type[AttackParamsT] = AttackParameters, # type: ignore[assignment] + params_type: type[AttackParamsT] = AttackParameters, # type: ignore[assignment] logger: logging.Logger = logger, ): """ diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 132abc0a4a..4e1b009d09 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -5,8 +5,9 @@ import enum import logging +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index a8144e2d01..cb77e55984 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -12,7 +12,7 @@ import logging from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -41,7 +41,7 @@ async def generate_simulated_conversation_async( next_message_system_prompt_path: Optional[Union[str, Path]] = None, attack_converter_config: Optional[AttackConverterConfig] = None, memory_labels: Optional[dict[str, str]] = None, -) -> List[SeedPrompt]: +) -> list[SeedPrompt]: """ Generate a simulated conversation between an adversarial chat and a target. @@ -124,7 +124,7 @@ async def generate_simulated_conversation_async( logger.info(f"Generating {num_turns}-turn simulated conversation for objective: {objective[:50]}...") # Build prepended_conversation - only include system message if prompt is provided - prepended_conversation: List[Message] = [] + prepended_conversation: list[Message] = [] if simulated_target_system_prompt: prepended_conversation.append(Message.from_system_prompt(simulated_target_system_prompt)) @@ -140,7 +140,7 @@ async def generate_simulated_conversation_async( # Filter out system messages - keep the actual conversation # System prompts are set separately on each target during attack execution - conversation_messages: List[Message] = [msg for msg in raw_messages if msg.api_role != "system"] + conversation_messages: list[Message] = [msg for msg in raw_messages if msg.api_role != "system"] # If next_message_system_prompt_path is provided, generate a final user message if next_message_system_prompt_path: @@ -166,7 +166,7 @@ async def generate_simulated_conversation_async( async def _generate_next_message_async( *, objective: str, - conversation_messages: List[Message], + conversation_messages: list[Message], adversarial_chat: PromptChatTarget, next_message_system_prompt_path: Union[str, Path], ) -> Message: @@ -217,7 +217,7 @@ async def _generate_next_message_async( conversation_id=request_message.conversation_id, ) - responses: List[Message] = await adversarial_chat.send_prompt_async(message=request_message) + responses: list[Message] = await adversarial_chat.send_prompt_async(message=request_message) if not responses: raise ValueError("No response received from adversarial chat when generating next message") diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index ac21f8f968..906c2e4ca7 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, cast, overload +from typing import Any, Optional, cast, overload from treelib.tree import Tree @@ -85,7 +85,7 @@ def __init__( *, objective_scorer: FloatScaleThresholdScorer, refusal_scorer: Optional[TrueFalseScorer] = None, - auxiliary_scorers: Optional[List[Scorer]] = None, + auxiliary_scorers: Optional[list[Scorer]] = None, use_score_as_feedback: bool = True, ) -> None: """ @@ -146,7 +146,7 @@ class TAPAttackContext(MultiTurnAttackContext[Any]): # Nodes in the attack tree # Each node represents a branch in the attack tree with its own state - nodes: List["_TreeOfAttacksNode"] = field(default_factory=list) + nodes: list["_TreeOfAttacksNode"] = field(default_factory=list) # Best conversation ID and score found during the attack best_conversation_id: Optional[str] = None @@ -204,12 +204,12 @@ def max_depth_reached(self, value: int) -> None: self.metadata["max_depth_reached"] = value @property - def auxiliary_scores_summary(self) -> Dict[str, float]: + def auxiliary_scores_summary(self) -> dict[str, float]: """Get a summary of auxiliary scores from the best node.""" - return cast(Dict[str, float], self.metadata.get("auxiliary_scores_summary", {})) + return cast(dict[str, float], self.metadata.get("auxiliary_scores_summary", {})) @auxiliary_scores_summary.setter - def auxiliary_scores_summary(self, value: Dict[str, float]) -> None: + def auxiliary_scores_summary(self, value: dict[str, float]) -> None: """Set the auxiliary scores summary.""" self.metadata["auxiliary_scores_summary"] = value @@ -265,9 +265,9 @@ def __init__( desired_response_prefix: str, objective_scorer: Scorer, on_topic_scorer: Optional[Scorer], - request_converters: List[PromptConverterConfiguration], - response_converters: List[PromptConverterConfiguration], - auxiliary_scorers: Optional[List[Scorer]], + request_converters: list[PromptConverterConfiguration], + response_converters: list[PromptConverterConfiguration], + auxiliary_scorers: Optional[list[Scorer]], attack_id: ComponentIdentifier, attack_strategy_name: str, memory_labels: Optional[dict[str, str]] = None, @@ -330,7 +330,7 @@ def __init__( self.completed = False self.off_topic = False self.objective_score: Optional[Score] = None - self.auxiliary_scores: Dict[str, Score] = {} + self.auxiliary_scores: dict[str, Score] = {} self.last_prompt_sent: Optional[str] = None self.last_response: Optional[str] = None self.error_message: Optional[str] = None @@ -348,7 +348,7 @@ def __init__( async def initialize_with_prepended_conversation_async( self, *, - prepended_conversation: List[Message], + prepended_conversation: list[Message], prepended_conversation_config: Optional["PrependedConversationConfig"] = None, ) -> None: """ @@ -1868,7 +1868,7 @@ def _create_attack_node( return node - def _get_completed_nodes_sorted_by_score(self, nodes: List[_TreeOfAttacksNode]) -> List[_TreeOfAttacksNode]: + def _get_completed_nodes_sorted_by_score(self, nodes: list[_TreeOfAttacksNode]) -> list[_TreeOfAttacksNode]: """ Get completed, on-topic nodes sorted by score in descending order. @@ -2100,7 +2100,7 @@ def _get_last_response_from_conversation(self, conversation_id: Optional[str]) - responses = self._memory.get_message_pieces(conversation_id=conversation_id) return responses[-1] if responses else None - def _get_auxiliary_scores_summary(self, nodes: List[_TreeOfAttacksNode]) -> Dict[str, float]: + def _get_auxiliary_scores_summary(self, nodes: list[_TreeOfAttacksNode]) -> dict[str, float]: """ Extract auxiliary scores from the best node if available. @@ -2120,7 +2120,7 @@ def _get_auxiliary_scores_summary(self, nodes: List[_TreeOfAttacksNode]) -> Dict return {name: float(score.get_value()) for name, score in nodes[0].auxiliary_scores.items()} - def _calculate_tree_statistics(self, tree_visualization: Tree) -> Dict[str, int]: + def _calculate_tree_statistics(self, tree_visualization: Tree) -> dict[str, int]: """ Calculate statistics from the tree visualization. diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 5d3ca0728d..357abe55ae 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -3,7 +3,6 @@ import os from datetime import datetime -from typing import List from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter from pyrit.memory import CentralMemory @@ -31,7 +30,7 @@ def __init__(self, *, display_inline: bool = True): self._memory = CentralMemory.get_memory_instance() self._display_inline = display_inline - def _render_markdown(self, markdown_lines: List[str]) -> None: + def _render_markdown(self, markdown_lines: list[str]) -> None: """ Render the markdown content using appropriate display method. @@ -209,7 +208,7 @@ async def print_summary_async(self, result: AttackResult) -> None: async def _get_conversation_markdown_async( self, *, result: AttackResult, include_scores: bool = False - ) -> List[str]: + ) -> list[str]: """ Generate markdown lines for the conversation history. @@ -260,7 +259,7 @@ async def _get_conversation_markdown_async( return markdown_lines - def _format_system_message(self, message: Message) -> List[str]: + def _format_system_message(self, message: Message) -> list[str]: """ Format a system message as markdown. @@ -278,7 +277,7 @@ def _format_system_message(self, message: Message) -> List[str]: lines.append(f"{piece.converted_value}\n") return lines - async def _format_user_message_async(self, *, message: Message, turn_number: int) -> List[str]: + async def _format_user_message_async(self, *, message: Message, turn_number: int) -> list[str]: """ Format a user message as markdown with turn numbering. @@ -300,7 +299,7 @@ async def _format_user_message_async(self, *, message: Message, turn_number: int return lines - async def _format_assistant_message_async(self, *, message: Message) -> List[str]: + async def _format_assistant_message_async(self, *, message: Message) -> list[str]: """ Format an assistant or system response message as markdown. @@ -343,7 +342,7 @@ def _get_audio_mime_type(self, *, audio_path: str) -> str: return "audio/mp4" return "audio/mpeg" # Default fallback for .mp3, .mpeg, and unknown formats - def _format_image_content(self, *, image_path: str) -> List[str]: + def _format_image_content(self, *, image_path: str) -> list[str]: """ Format image content as markdown. @@ -357,7 +356,7 @@ def _format_image_content(self, *, image_path: str) -> List[str]: posix_path = relative_path.replace("\\", "/") return [f"![Image]({posix_path})\n"] - def _format_audio_content(self, *, audio_path: str) -> List[str]: + def _format_audio_content(self, *, audio_path: str) -> list[str]: """ Format audio content as HTML5 audio player. @@ -378,7 +377,7 @@ def _format_audio_content(self, *, audio_path: str) -> List[str]: return lines - def _format_error_content(self, *, piece: MessagePiece) -> List[str]: + def _format_error_content(self, *, piece: MessagePiece) -> list[str]: """ Format error response content with proper styling. @@ -397,7 +396,7 @@ def _format_error_content(self, *, piece: MessagePiece) -> List[str]: return lines - def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> List[str]: + def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> list[str]: """ Format regular text content. @@ -419,7 +418,7 @@ def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> L return lines - async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> List[str]: + async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> list[str]: """ Format a single piece content based on its data type. @@ -442,7 +441,7 @@ async def _format_piece_content_async(self, *, piece: MessagePiece, show_origina return self._format_error_content(piece=piece) return self._format_text_content(piece=piece, show_original=show_original) - def _format_message_scores(self, message: Message) -> List[str]: + def _format_message_scores(self, message: Message) -> list[str]: """ Format scores for all pieces in a message as markdown. @@ -467,7 +466,7 @@ def _format_message_scores(self, message: Message) -> List[str]: lines.append("") return lines - async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]: + async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: """ Generate markdown lines for the attack summary. @@ -517,7 +516,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]: return markdown_lines - async def _get_pruned_conversations_markdown_async(self, result: AttackResult) -> List[str]: + async def _get_pruned_conversations_markdown_async(self, result: AttackResult) -> list[str]: """ Generate markdown lines for pruned conversations. @@ -578,7 +577,7 @@ async def _get_pruned_conversations_markdown_async(self, result: AttackResult) - return markdown_lines - async def _get_adversarial_conversation_markdown_async(self, result: AttackResult) -> List[str]: + async def _get_adversarial_conversation_markdown_async(self, result: AttackResult) -> list[str]: """ Generate markdown lines for the adversarial conversation. diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index d18dba6013..2dbdffd166 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -3,7 +3,7 @@ import logging import uuid -from typing import Any, Optional, Type +from typing import Any, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import warn_if_set @@ -58,7 +58,7 @@ def __init__( attack_scoring_config: Optional[AttackScoringConfig] = None, prompt_normalizer: Optional[PromptNormalizer] = None, max_attempts_on_failure: int = 0, - params_type: Type[AttackParamsT] = AttackParameters, # type: ignore[assignment] + params_type: type[AttackParamsT] = AttackParameters, # type: ignore[assignment] prepended_conversation_config: Optional[PrependedConversationConfig] = None, ) -> None: """ diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index 7a8ff9d399..e3c0560bf5 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -48,7 +48,7 @@ def __init__( *, objective_target: PromptTarget, context_type: type[SingleTurnAttackContext[Any]] = SingleTurnAttackContext, - params_type: Type[AttackParamsT] = AttackParameters, # type: ignore[assignment] + params_type: type[AttackParamsT] = AttackParameters, # type: ignore[assignment] logger: logging.Logger = logger, ): """ diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 341caa7910..f728d27dba 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -6,7 +6,7 @@ import uuid from collections import Counter from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, cast, overload +from typing import Any, Optional, cast, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -47,16 +47,16 @@ class FairnessBiasBenchmarkContext(StrategyContext): objective: Optional[str] = None # Prepended conversation for context - prepended_conversation: List[Message] = field(default_factory=list) + prepended_conversation: list[Message] = field(default_factory=list) # Memory labels for tracking - memory_labels: Dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) # Generated fields for the benchmark run generated_objective: str = field(default_factory=str) generated_story_prompt: str = field(default_factory=str) generated_message: Optional[Message] = None - experiment_results: List[Dict[str, str]] = field(default_factory=list) + experiment_results: list[dict[str, str]] = field(default_factory=list) class FairnessBiasBenchmark(Strategy[FairnessBiasBenchmarkContext, AttackResult]): @@ -226,7 +226,7 @@ async def _run_experiment(self, context: FairnessBiasBenchmarkContext) -> Attack def _format_experiment_results( self, context: FairnessBiasBenchmarkContext, attack_result: AttackResult, experiment_num: int - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Format the experiment data into a dictionary. @@ -330,7 +330,7 @@ def _extract_name_with_regex(self, response: str) -> Optional[str]: return " ".join(cap_words) return None - def get_experiment_summary(self, *, context: FairnessBiasBenchmarkContext) -> Dict[str, Any]: + def get_experiment_summary(self, *, context: FairnessBiasBenchmarkContext) -> dict[str, Any]: """ Get a summary of the experiment results. @@ -388,8 +388,8 @@ async def execute_async( story_type: str, num_experiments: int = 1, objective: Optional[str] = None, - prepended_conversation: Optional[List[Message]] = None, - memory_labels: Optional[Dict[str, str]] = None, + prepended_conversation: Optional[list[Message]] = None, + memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> AttackResult: ... diff --git a/pyrit/executor/benchmark/question_answering.py b/pyrit/executor/benchmark/question_answering.py index d2a244a381..ddf128623e 100644 --- a/pyrit/executor/benchmark/question_answering.py +++ b/pyrit/executor/benchmark/question_answering.py @@ -4,7 +4,7 @@ import logging import textwrap from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, overload +from typing import Any, Optional, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -34,10 +34,10 @@ class QuestionAnsweringBenchmarkContext(StrategyContext): question_answering_entry: QuestionAnsweringEntry # Prepended conversation for context - prepended_conversation: List[Message] = field(default_factory=list) + prepended_conversation: list[Message] = field(default_factory=list) # Memory labels for tracking - memory_labels: Dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) # Generated fields for the benchmark run # The generated objective for the benchmark @@ -260,8 +260,8 @@ async def execute_async( self, *, question_answering_entry: QuestionAnsweringEntry, - prepended_conversation: Optional[List[Message]] = None, - memory_labels: Optional[Dict[str, str]] = None, + prepended_conversation: Optional[list[Message]] = None, + memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> AttackResult: ... diff --git a/pyrit/executor/core/config.py b/pyrit/executor/core/config.py index 1c6ec694be..5b8657dcdd 100644 --- a/pyrit/executor/core/config.py +++ b/pyrit/executor/core/config.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass, field -from typing import List from pyrit.prompt_normalizer import ( PromptConverterConfiguration, @@ -19,7 +18,7 @@ class StrategyConverterConfig: """ # List of converter configurations to apply to target requests/prompts - request_converters: List[PromptConverterConfiguration] = field(default_factory=list) + request_converters: list[PromptConverterConfiguration] = field(default_factory=list) # List of converter configurations to apply to target responses - response_converters: List[PromptConverterConfiguration] = field(default_factory=list) + response_converters: list[PromptConverterConfiguration] = field(default_factory=list) diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 1ef0f94cff..9d965865e6 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -8,11 +8,12 @@ import logging import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, MutableMapping from contextlib import asynccontextmanager from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import Any, AsyncIterator, Dict, Generic, MutableMapping, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar from pyrit.common import default_values from pyrit.common.logger import logger @@ -160,7 +161,7 @@ def __init__( """ self._id = uuid.uuid4() self._context_type = context_type - self._event_handlers: Dict[str, StrategyEventHandler[StrategyContextT, StrategyResultT]] = {} + self._event_handlers: dict[str, StrategyEventHandler[StrategyContextT, StrategyResultT]] = {} if event_handler is not None: self._register_event_handler(event_handler) @@ -172,7 +173,7 @@ def __init__( StrategyLogAdapter._STRATEGY_ID_KEY: str(self._id)[:8], }, ) - self._memory_labels: Dict[str, str] = ast.literal_eval( + self._memory_labels: dict[str, str] = ast.literal_eval( default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS") or "{}" ) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 965e852d33..f33d8b7ae4 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, Union, overload +from typing import Any, Optional, Union, overload import yaml @@ -39,7 +39,7 @@ class AnecdoctorContext(PromptGeneratorStrategyContext): """ # The data in ClaimsReview format to use in constructing the prompt - evaluation_data: List[str] + evaluation_data: list[str] # The language of the content to generate (e.g., "english", "spanish") language: str @@ -51,7 +51,7 @@ class AnecdoctorContext(PromptGeneratorStrategyContext): conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) # Optional memory labels to apply to the prompts - memory_labels: Dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) @dataclass @@ -138,8 +138,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. @@ -152,7 +152,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = { + all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -331,7 +331,7 @@ def _load_prompt_from_yaml(self, *, yaml_filename: str) -> str: yaml_data = yaml.safe_load(prompt_data) return str(yaml_data["value"]) - def _format_few_shot_examples(self, *, evaluation_data: List[str]) -> str: + def _format_few_shot_examples(self, *, evaluation_data: list[str]) -> str: """ Format the evaluation data as few-shot examples. @@ -404,7 +404,7 @@ async def execute_async( *, content_type: str, language: str, - evaluation_data: List[str], + evaluation_data: list[str], memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> AnecdoctorResult: ... diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 6617140e4f..57593ba3a9 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -8,7 +8,7 @@ import textwrap import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union, overload +from typing import Any, Optional, Union, overload import numpy as np from colorama import Fore, Style @@ -106,7 +106,7 @@ def __init__( self.minimum_reward = minimum_reward self.non_leaf_node_probability = non_leaf_node_probability - def select_node(self, *, initial_nodes: List[_PromptNode], step: int) -> Tuple[_PromptNode, List[_PromptNode]]: + def select_node(self, *, initial_nodes: list[_PromptNode], step: int) -> tuple[_PromptNode, list[_PromptNode]]: """ Select a node using MCTS-explore algorithm. @@ -153,7 +153,7 @@ def _calculate_uct_score(self, *, node: _PromptNode, step: int) -> float: exploration = self.frequency_weight * np.sqrt(2 * np.log(step) / (node.visited_num + 0.01)) return float(exploitation + exploration) - def update_rewards(self, path: List[_PromptNode], reward: float, last_node: Optional[_PromptNode] = None) -> None: + def update_rewards(self, path: list[_PromptNode], reward: float, last_node: Optional[_PromptNode] = None) -> None: """ Update rewards for nodes in the path. @@ -179,24 +179,24 @@ class FuzzerContext(PromptGeneratorStrategyContext): """ # Per-execution input data - prompts: List[str] - prompt_templates: List[str] + prompts: list[str] + prompt_templates: list[str] max_query_limit: Optional[int] = None # Tracking state total_target_query_count: int = 0 total_jailbreak_count: int = 0 - jailbreak_conversation_ids: List[Union[str, uuid.UUID]] = field(default_factory=list) + jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = field(default_factory=list) executed_turns: int = 0 # Tree structure - initial_prompt_nodes: List[_PromptNode] = field(default_factory=list) - new_prompt_nodes: List[_PromptNode] = field(default_factory=list) - mcts_selected_path: List[_PromptNode] = field(default_factory=list) + initial_prompt_nodes: list[_PromptNode] = field(default_factory=list) + new_prompt_nodes: list[_PromptNode] = field(default_factory=list) + mcts_selected_path: list[_PromptNode] = field(default_factory=list) last_choice_node: Optional[_PromptNode] = None # Optional memory labels to apply to the prompts - memory_labels: Dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: """ @@ -219,8 +219,8 @@ class FuzzerResult(PromptGeneratorStrategyResult): """ # Concrete fields instead of metadata storage - successful_templates: List[str] = field(default_factory=list) - jailbreak_conversation_ids: List[Union[str, uuid.UUID]] = field(default_factory=list) + successful_templates: list[str] = field(default_factory=list) + jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = field(default_factory=list) total_queries: int = 0 templates_explored: int = 0 @@ -536,7 +536,7 @@ class FuzzerGenerator( def with_default_scorer( *, objective_target: PromptTarget, - template_converters: List[FuzzerConverter], + template_converters: list[FuzzerConverter], scoring_target: PromptChatTarget, converter_config: Optional[StrategyConverterConfig] = None, prompt_normalizer: Optional[PromptNormalizer] = None, @@ -546,7 +546,7 @@ def with_default_scorer( non_leaf_node_probability: float = _DEFAULT_NON_LEAF_PROBABILITY, batch_size: int = _DEFAULT_BATCH_SIZE, target_jailbreak_goal_count: int = _DEFAULT_TARGET_JAILBREAK_COUNT, - ) -> "FuzzerGenerator": + ) -> FuzzerGenerator: """ Create a FuzzerGenerator instance with default scoring configuration. @@ -604,7 +604,7 @@ def __init__( self, *, objective_target: PromptTarget, - template_converters: List[FuzzerConverter], + template_converters: list[FuzzerConverter], converter_config: Optional[StrategyConverterConfig] = None, scorer: Optional[Scorer] = None, scoring_success_threshold: float = 0.8, @@ -682,8 +682,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. @@ -696,7 +696,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = { + all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -715,7 +715,7 @@ def _build_identifier(self) -> ComponentIdentifier: def _validate_inputs( self, *, - template_converters: List[FuzzerConverter], + template_converters: list[FuzzerConverter], batch_size: int, ) -> None: """ @@ -889,7 +889,7 @@ def _should_stop_generation(self, context: FuzzerContext) -> bool: return False - def _select_template_with_mcts(self, context: FuzzerContext) -> Tuple[_PromptNode, List[_PromptNode]]: + def _select_template_with_mcts(self, context: FuzzerContext) -> tuple[_PromptNode, list[_PromptNode]]: """ Select a template using the MCTS-explore algorithm. @@ -946,7 +946,7 @@ async def _apply_template_converter_async(self, *, context: FuzzerContext, curre return converted.output_text - def _get_other_templates(self, context: FuzzerContext) -> List[str]: + def _get_other_templates(self, context: FuzzerContext) -> list[str]: """ Get templates not in the current MCTS path. @@ -965,7 +965,7 @@ def _get_other_templates(self, context: FuzzerContext) -> List[str]: return other_templates - def _generate_prompts_from_template(self, *, template: SeedPrompt, prompts: List[str]) -> List[str]: + def _generate_prompts_from_template(self, *, template: SeedPrompt, prompts: list[str]) -> list[str]: """ Generate jailbreak prompts by filling template with prompts. @@ -985,7 +985,7 @@ def _generate_prompts_from_template(self, *, template: SeedPrompt, prompts: List return [template.render_template_value(prompt=prompt) for prompt in prompts] - async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts: List[str]) -> List[Message]: + async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts: list[str]) -> list[Message]: """ Send prompts to the target in batches. @@ -1006,7 +1006,7 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts batch_size=self._batch_size, ) - def _create_normalizer_requests(self, prompts: List[str]) -> List[NormalizerRequest]: + def _create_normalizer_requests(self, prompts: list[str]) -> list[NormalizerRequest]: """ Create normalizer requests from prompts. @@ -1016,7 +1016,7 @@ def _create_normalizer_requests(self, prompts: List[str]) -> List[NormalizerRequ Returns: List of normalizer requests. """ - requests: List[NormalizerRequest] = [] + requests: list[NormalizerRequest] = [] for prompt in prompts: seed_group = SeedGroup(seeds=[SeedPrompt(value=prompt, data_type="text")]) @@ -1029,7 +1029,7 @@ def _create_normalizer_requests(self, prompts: List[str]) -> List[NormalizerRequ return requests - async def _score_responses_async(self, *, responses: List[Message], tasks: List[str]) -> List[Score]: + async def _score_responses_async(self, *, responses: list[Message], tasks: list[str]) -> list[Score]: """ Score the responses from the target. @@ -1054,8 +1054,8 @@ def _process_scoring_results( self, *, context: FuzzerContext, - scores: List[Score], - responses: List[Message], + scores: list[Score], + responses: list[Message], template_node: _PromptNode, current_seed: _PromptNode, ) -> int: @@ -1189,8 +1189,8 @@ def _create_generation_result(self, context: FuzzerContext) -> FuzzerResult: async def execute_async( self, *, - prompts: List[str], - prompt_templates: List[str], + prompts: list[str], + prompt_templates: list[str], max_query_limit: Optional[int] = None, memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 8001aec9c7..2f2b44f4a7 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -4,7 +4,7 @@ import pathlib import random import uuid -from typing import Any, List, Optional +from typing import Any, Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +27,7 @@ def __init__( *, converter_target: Optional[PromptChatTarget] = None, prompt_template: Optional[SeedPrompt] = None, - prompt_templates: Optional[List[str]] = None, + prompt_templates: Optional[list[str]] = None, ): """ Initialize the converter with the specified chat target and prompt templates. diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 4b92a54cbc..6326ecbfa7 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union, overload +from typing import Any, Optional, Protocol, Union, overload from pyrit.common.utils import combine_dict, get_kwarg_param from pyrit.executor.core import StrategyConverterConfig @@ -77,7 +77,7 @@ class XPIAContext(WorkflowContext): processing_prompt: Optional[Message] = None # Additional labels that can be applied throughout the workflow - memory_labels: Dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) @dataclass @@ -178,8 +178,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the identifier for this XPIA workflow. @@ -192,7 +192,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this XPIA workflow. """ - all_children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = { + all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { "attack_setup_target": self._attack_setup_target.get_identifier(), } if self._scorer: @@ -424,7 +424,7 @@ async def execute_async( attack_content: Message, processing_callback: Optional[XPIAProcessingCallback] = None, processing_prompt: Optional[Message] = None, - memory_labels: Optional[Dict[str, str]] = None, + memory_labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> XPIAResult: ... diff --git a/pyrit/identifiers/component_identifier.py b/pyrit/identifiers/component_identifier.py index 27707735a4..fe306053ae 100644 --- a/pyrit/identifiers/component_identifier.py +++ b/pyrit/identifiers/component_identifier.py @@ -21,7 +21,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Optional, Union import pyrit from pyrit.common.deprecation import print_deprecation_message @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) -def config_hash(config_dict: Dict[str, Any]) -> str: +def config_hash(config_dict: dict[str, Any]) -> str: """ Compute a deterministic SHA256 hash from a config dictionary. @@ -54,9 +54,9 @@ def _build_hash_dict( *, class_name: str, class_module: str, - params: Dict[str, Any], - children: Dict[str, Any], -) -> Dict[str, Any]: + params: dict[str, Any], + children: dict[str, Any], +) -> dict[str, Any]: """ Build the canonical dictionary used for hash computation. @@ -73,7 +73,7 @@ def _build_hash_dict( Returns: Dict[str, Any]: The canonical dictionary for hashing. """ - hash_dict: Dict[str, Any] = { + hash_dict: dict[str, Any] = { ComponentIdentifier.KEY_CLASS_NAME: class_name, ComponentIdentifier.KEY_CLASS_MODULE: class_module, } @@ -86,7 +86,7 @@ def _build_hash_dict( # Children contribute their hashes, not their full structure. if children: - children_hashes: Dict[str, Any] = {} + children_hashes: dict[str, Any] = {} for name, child in sorted(children.items()): if isinstance(child, ComponentIdentifier): children_hashes[name] = child.hash @@ -125,9 +125,9 @@ class ComponentIdentifier: #: Full module path (e.g., "pyrit.score.self_ask_scale_scorer"). class_module: str #: Behavioral parameters that affect output. - params: Dict[str, Any] = field(default_factory=dict) + params: dict[str, Any] = field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). - children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = field(default_factory=dict) + children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = field(default_factory=dict) #: Content-addressed SHA256 hash computed from class, params, and children. hash: str = field(init=False, compare=False) #: Version tag for storage. Not included in hash. @@ -170,8 +170,8 @@ def of( cls, obj: object, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Build a ComponentIdentifier from a live object instance. @@ -204,7 +204,7 @@ def of( ) @classmethod - def normalize(cls, value: Union[ComponentIdentifier, Dict[str, Any]]) -> ComponentIdentifier: + def normalize(cls, value: Union[ComponentIdentifier, dict[str, Any]]) -> ComponentIdentifier: """ Normalize a value to a ComponentIdentifier instance. @@ -233,7 +233,7 @@ def normalize(cls, value: Union[ComponentIdentifier, Dict[str, Any]]) -> Compone return cls.from_dict(value) raise TypeError(f"Expected ComponentIdentifier or dict, got {type(value).__name__}") - def to_dict(self, *, max_value_length: Optional[int] = None) -> Dict[str, Any]: + def to_dict(self, *, max_value_length: Optional[int] = None) -> dict[str, Any]: """ Serialize to a JSON-compatible dictionary for DB/JSONL storage. @@ -253,7 +253,7 @@ def to_dict(self, *, max_value_length: Optional[int] = None) -> Dict[str, Any]: Dict[str, Any]: JSON-serializable dictionary suitable for database storage or JSONL export. """ - result: Dict[str, Any] = { + result: dict[str, Any] = { self.KEY_CLASS_NAME: self.class_name, self.KEY_CLASS_MODULE: self.class_module, self.KEY_HASH: self.hash, @@ -264,7 +264,7 @@ def to_dict(self, *, max_value_length: Optional[int] = None) -> Dict[str, Any]: result[key] = self._truncate_value(value=value, max_length=max_value_length) if self.children: - serialized_children: Dict[str, Any] = {} + serialized_children: dict[str, Any] = {} for name, child in self.children.items(): if isinstance(child, ComponentIdentifier): serialized_children[name] = child.to_dict(max_value_length=max_value_length) @@ -293,7 +293,7 @@ def _truncate_value(*, value: Any, max_length: Optional[int]) -> Any: return value @classmethod - def from_dict(cls, data: Dict[str, Any]) -> ComponentIdentifier: + def from_dict(cls, data: dict[str, Any]) -> ComponentIdentifier: """ Deserialize from a stored dictionary. @@ -370,7 +370,7 @@ def get_child(self, key: str) -> Optional[ComponentIdentifier]: raise ValueError(f"Child '{key}' is a list of {len(child)} components. Use get_child_list() instead.") return child - def get_child_list(self, key: str) -> List[ComponentIdentifier]: + def get_child_list(self, key: str) -> list[ComponentIdentifier]: """ Get a list of children by key. @@ -390,8 +390,8 @@ def get_child_list(self, key: str) -> List[ComponentIdentifier]: @classmethod def _reconstruct_children( - cls, children_dict: Optional[Dict[str, Any]] - ) -> Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]: + cls, children_dict: Optional[dict[str, Any]] + ) -> dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]: """ Reconstruct child identifiers from raw dictionary data. @@ -402,7 +402,7 @@ def _reconstruct_children( Returns: Dict mapping child names to reconstructed ComponentIdentifier instances or lists thereof. """ - children: Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]] = {} + children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = {} if not children_dict or not isinstance(children_dict, dict): return children diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index e6d36120f2..6db0198562 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -3,9 +3,10 @@ import logging import struct +from collections.abc import MutableSequence, Sequence from contextlib import closing from datetime import datetime, timedelta, timezone -from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, TypeVar, Union from azure.core.credentials import AccessToken from sqlalchemy import and_, create_engine, event, exists, text @@ -100,7 +101,7 @@ def __init__( self.SessionFactory = sessionmaker(bind=self.engine) self._create_tables_if_not_exist() - super(AzureSQLMemory, self).__init__() + super().__init__() @staticmethod def _resolve_sas_token(env_var_name: str, passed_value: Optional[str] = None) -> Optional[str]: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 4836a8cd27..4fd22d5c85 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -7,10 +7,11 @@ import uuid import warnings import weakref +from collections.abc import MutableSequence, Sequence from contextlib import closing from datetime import datetime from pathlib import Path -from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, TypeVar, Union from sqlalchemy import MetaData, and_, or_ from sqlalchemy.engine.base import Engine diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 5f127ee1d5..108a56625a 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -5,7 +5,7 @@ import logging import uuid from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, ConfigDict from sqlalchemy import ( @@ -170,8 +170,8 @@ class PromptMemoryEntry(Base): timestamp = mapped_column(DateTime, nullable=False) labels: Mapped[dict[str, str]] = mapped_column(JSON) prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON) - targeted_harm_categories: Mapped[Optional[List[str]]] = mapped_column(JSON) - converter_identifiers: Mapped[Optional[List[Dict[str, str]]]] = mapped_column(JSON) + targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON) + converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON) prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) @@ -196,7 +196,7 @@ class PromptMemoryEntry(Base): # Nullable for backwards compatibility with existing databases pyrit_version = mapped_column(String, nullable=True) - scores: Mapped[List["ScoreEntry"]] = relationship( + scores: Mapped[list["ScoreEntry"]] = relationship( "ScoreEntry", primaryjoin="ScoreEntry.prompt_request_response_id == PromptMemoryEntry.original_prompt_id", back_populates="prompt_request_piece", @@ -254,7 +254,7 @@ def get_message_piece(self) -> MessagePiece: MessagePiece: The reconstructed message piece with all its data and scores. """ # Reconstruct ComponentIdentifiers with the stored pyrit_version - converter_ids: Optional[List[Union[ComponentIdentifier, Dict[str, str]]]] = None + converter_ids: Optional[list[Union[ComponentIdentifier, dict[str, str]]]] = None stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION if self.converter_identifiers: converter_ids = [ @@ -527,15 +527,15 @@ class SeedEntry(Base): data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) name = mapped_column(String, nullable=True) dataset_name = mapped_column(String, nullable=True) - harm_categories: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) description = mapped_column(String, nullable=True) - authors: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) - groups: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + authors: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + groups: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) source = mapped_column(String, nullable=True) date_added = mapped_column(DateTime, nullable=False) added_by = mapped_column(String, nullable=False) prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON, nullable=True) - parameters: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + parameters: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) prompt_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(CustomUUID, nullable=True) sequence: Mapped[Optional[int]] = mapped_column(INTEGER, nullable=True) role: Mapped[ChatMessageRole] = mapped_column(String, nullable=True) @@ -724,8 +724,8 @@ class AttackResultEntry(Base): ) outcome_reason = mapped_column(String, nullable=True) attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) - pruned_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) - adversarial_chat_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) # Version of PyRIT used when this attack result was created # Nullable for backwards compatibility with existing databases @@ -798,7 +798,7 @@ def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: return None @staticmethod - def filter_json_serializable_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + def filter_json_serializable_metadata(metadata: dict[str, Any]) -> dict[str, Any]: """ Filter a dictionary to only include JSON-serializable values. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 9c881acf48..6738abff6c 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -3,10 +3,11 @@ import logging import uuid +from collections.abc import MutableSequence, Sequence from contextlib import closing from datetime import datetime from pathlib import Path -from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, TypeVar, Union from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine @@ -58,7 +59,7 @@ def __init__( verbose (bool): Whether to enable verbose logging. Defaults to False. """ - super(SQLiteMemory, self).__init__() + super().__init__() if db_path == ":memory:": self.db_path: Union[Path, str] = ":memory:" diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 5e23e4844a..3a11938edf 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,7 +4,7 @@ import base64 import json import os -from typing import Any, List, Union +from typing import Any, Union from pyrit.common import convert_local_image_to_data_url from pyrit.message_normalizer.message_normalizer import ( @@ -56,7 +56,7 @@ def __init__( self.use_developer_role = use_developer_role self.system_message_behavior = system_message_behavior - async def normalize_async(self, messages: List[Message]) -> List[ChatMessage]: + async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: """ Convert a list of Messages to a list of ChatMessages. @@ -78,7 +78,7 @@ async def normalize_async(self, messages: List[Message]) -> List[ChatMessage]: # Apply system message preprocessing processed_messages = await apply_system_message_behavior(messages, self.system_message_behavior) - chat_messages: List[ChatMessage] = [] + chat_messages: list[ChatMessage] = [] for message in processed_messages: pieces = message.message_pieces role: ChatMessageRole = pieces[0].api_role @@ -89,7 +89,7 @@ async def normalize_async(self, messages: List[Message]) -> List[ChatMessage]: # Use simple string for single text piece, otherwise use content list if len(pieces) == 1 and pieces[0].converted_value_data_type == "text": - content: Union[str, List[dict[str, Any]]] = pieces[0].converted_value + content: Union[str, list[dict[str, Any]]] = pieces[0].converted_value else: content = [await self._piece_to_content_dict_async(piece) for piece in pieces] @@ -97,7 +97,7 @@ async def normalize_async(self, messages: List[Message]) -> List[ChatMessage]: return chat_messages - async def normalize_string_async(self, messages: List[Message]) -> str: + async def normalize_string_async(self, messages: list[Message]) -> str: """ Convert a list of Messages to a JSON string representation. diff --git a/pyrit/message_normalizer/conversation_context_normalizer.py b/pyrit/message_normalizer/conversation_context_normalizer.py index 8238c1e2be..46e3c96308 100644 --- a/pyrit/message_normalizer/conversation_context_normalizer.py +++ b/pyrit/message_normalizer/conversation_context_normalizer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List from pyrit.message_normalizer.message_normalizer import MessageStringNormalizer from pyrit.models import Message, MessagePiece @@ -24,7 +23,7 @@ class ConversationContextNormalizer(MessageStringNormalizer): ... """ - async def normalize_string_async(self, messages: List[Message]) -> str: + async def normalize_string_async(self, messages: list[Message]) -> str: """ Normalize a list of messages into a turn-based context string. @@ -40,7 +39,7 @@ async def normalize_string_async(self, messages: List[Message]) -> str: if not messages: raise ValueError("Messages list cannot be empty") - context_parts: List[str] = [] + context_parts: list[str] = [] turn_number = 0 for message in messages: diff --git a/pyrit/message_normalizer/generic_system_squash.py b/pyrit/message_normalizer/generic_system_squash.py index 4bc2686c8d..56850289da 100644 --- a/pyrit/message_normalizer/generic_system_squash.py +++ b/pyrit/message_normalizer/generic_system_squash.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List from pyrit.message_normalizer.message_normalizer import MessageListNormalizer from pyrit.models import Message @@ -12,7 +11,7 @@ class GenericSystemSquashNormalizer(MessageListNormalizer[Message]): Normalizer that combines the first system message with the first user message using generic instruction tags. """ - async def normalize_async(self, messages: List[Message]) -> List[Message]: + async def normalize_async(self, messages: list[Message]) -> list[Message]: """ Return messages with the first system message combined into the first user message. diff --git a/pyrit/message_normalizer/message_normalizer.py b/pyrit/message_normalizer/message_normalizer.py index 0af9f26c14..66b534eca4 100644 --- a/pyrit/message_normalizer/message_normalizer.py +++ b/pyrit/message_normalizer/message_normalizer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import abc -from typing import Any, Generic, List, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from pyrit.models import Message @@ -36,7 +36,7 @@ class MessageListNormalizer(abc.ABC, Generic[T]): """ @abc.abstractmethod - async def normalize_async(self, messages: List[Message]) -> List[T]: + async def normalize_async(self, messages: list[Message]) -> list[T]: """ Normalize the list of messages into a list of items. @@ -47,7 +47,7 @@ async def normalize_async(self, messages: List[Message]) -> List[T]: A list of normalized items of type T. """ - async def normalize_to_dicts_async(self, messages: List[Message]) -> List[dict[str, Any]]: + async def normalize_to_dicts_async(self, messages: list[Message]) -> list[dict[str, Any]]: """ Normalize the list of messages into a list of dictionaries. @@ -71,7 +71,7 @@ class MessageStringNormalizer(abc.ABC): """ @abc.abstractmethod - async def normalize_string_async(self, messages: List[Message]) -> str: + async def normalize_string_async(self, messages: list[Message]) -> str: """ Normalize the list of messages into a string representation. @@ -83,7 +83,7 @@ async def normalize_string_async(self, messages: List[Message]) -> str: """ -async def apply_system_message_behavior(messages: List[Message], behavior: SystemMessageBehavior) -> List[Message]: +async def apply_system_message_behavior(messages: list[Message], behavior: SystemMessageBehavior) -> list[Message]: """ Apply a system message behavior to a list of messages. diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index 69d1d43512..db31126efd 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, cast +from typing import TYPE_CHECKING, ClassVar, Literal, Optional, cast from pyrit.common import get_non_required_value from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer @@ -48,7 +48,7 @@ class TokenizerTemplateNormalizer(MessageStringNormalizer): """ # Alias mappings for common HuggingFace models - MODEL_ALIASES: ClassVar[Dict[str, TokenizerModelConfig]] = { + MODEL_ALIASES: ClassVar[dict[str, TokenizerModelConfig]] = { # No authentication required "chatml": TokenizerModelConfig( model_name="HuggingFaceH4/zephyr-7b-beta", @@ -187,7 +187,7 @@ def from_model( ), ) - async def normalize_string_async(self, messages: List[Message]) -> str: + async def normalize_string_async(self, messages: list[Message]) -> str: """ Apply the chat template stored in the tokenizer to a list of messages. diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 18b3150331..e40d0d228b 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Optional, TypeVar from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.conversation_reference import ConversationReference, ConversationType @@ -73,7 +73,7 @@ class AttackResult(StrategyResult): related_conversations: set[ConversationReference] = field(default_factory=set) # Arbitrary metadata - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) def get_conversations_by_type(self, conversation_type: ConversationType) -> list[ConversationReference]: """ diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 09f24e34cb..5283eb6e6e 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -33,7 +33,7 @@ def data_serializer_factory( value: Optional[str] = None, extension: Optional[str] = None, category: AllowedCategories, -) -> "DataTypeSerializer": +) -> DataTypeSerializer: """ Create a DataTypeSerializer instance. diff --git a/pyrit/models/harm_definition.py b/pyrit/models/harm_definition.py index 264e202350..65cd2f6d61 100644 --- a/pyrit/models/harm_definition.py +++ b/pyrit/models/harm_definition.py @@ -11,7 +11,7 @@ import re from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union import yaml @@ -54,7 +54,7 @@ class HarmDefinition: version: str category: str - scale_descriptions: List[ScaleDescription] = field(default_factory=list) + scale_descriptions: list[ScaleDescription] = field(default_factory=list) source_path: Optional[str] = field(default=None, kw_only=True) def get_scale_description(self, score_value: str) -> Optional[str]: @@ -140,7 +140,7 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": ) try: - with open(resolved_path, "r", encoding="utf-8") as f: + with open(resolved_path, encoding="utf-8") as f: data = yaml.safe_load(f) except yaml.YAMLError as e: raise ValueError(f"Invalid YAML in harm definition file {resolved_path}: {e}") @@ -178,7 +178,7 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": ) -def get_all_harm_definitions() -> Dict[str, HarmDefinition]: +def get_all_harm_definitions() -> dict[str, HarmDefinition]: """ Load all harm definitions from the standard harm_definition directory. @@ -194,7 +194,7 @@ def get_all_harm_definitions() -> Dict[str, HarmDefinition]: ValueError: If any YAML file in the directory is invalid. """ - harm_definitions: Dict[str, HarmDefinition] = {} + harm_definitions: dict[str, HarmDefinition] = {} if not HARM_DEFINITION_PATH.exists(): logger.warning(f"Harm definition directory does not exist: {HARM_DEFINITION_PATH}") diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py index f2ed20032d..ad28b1677a 100644 --- a/pyrit/models/json_response_config.py +++ b/pyrit/models/json_response_config.py @@ -5,7 +5,7 @@ import json from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Optional # Would prefer StrEnum, but.... Python 3.10 _METADATAKEYS = { @@ -28,12 +28,12 @@ class _JsonResponseConfig: """ enabled: bool = False - schema: Optional[Dict[str, Any]] = None + schema: Optional[dict[str, Any]] = None schema_name: str = "CustomSchema" strict: bool = True @classmethod - def from_metadata(cls, *, metadata: Optional[Dict[str, Any]]) -> _JsonResponseConfig: + def from_metadata(cls, *, metadata: Optional[dict[str, Any]]) -> _JsonResponseConfig: if not metadata: return cls(enabled=False) diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 340e052466..91f604f2bc 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -6,8 +6,9 @@ import copy import uuid import warnings +from collections.abc import MutableSequence, Sequence from datetime import datetime -from typing import Dict, MutableSequence, Optional, Sequence, Union +from typing import Optional, Union from pyrit.common.utils import combine_dict from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError @@ -365,7 +366,7 @@ def from_prompt( *, prompt: str, role: ChatMessageRole, - prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, + prompt_metadata: Optional[dict[str, Union[str, int]]] = None, ) -> Message: """ Build a single-piece message from prompt text. @@ -541,7 +542,7 @@ def construct_response_from_request( request: MessagePiece, response_text_pieces: list[str], response_type: PromptDataType = "text", - prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, + prompt_metadata: Optional[dict[str, Union[str, int]]] = None, error: PromptResponseError = "none", ) -> Message: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index fdaf4e8d24..9f1a0ea30e 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Union, get_args +from typing import Any, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.identifiers.component_identifier import ComponentIdentifier @@ -35,20 +35,20 @@ def __init__( id: Optional[uuid.UUID | str] = None, conversation_id: Optional[str] = None, sequence: int = -1, - labels: Optional[Dict[str, str]] = None, - prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, - converter_identifiers: Optional[List[Union[ComponentIdentifier, Dict[str, str]]]] = None, - prompt_target_identifier: Optional[Union[ComponentIdentifier, Dict[str, Any]]] = None, - attack_identifier: Optional[Union[ComponentIdentifier, Dict[str, str]]] = None, - scorer_identifier: Optional[Union[ComponentIdentifier, Dict[str, str]]] = None, + labels: Optional[dict[str, str]] = None, + prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + converter_identifiers: Optional[list[Union[ComponentIdentifier, dict[str, str]]]] = None, + prompt_target_identifier: Optional[Union[ComponentIdentifier, dict[str, Any]]] = None, + attack_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None, + scorer_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", converted_value_data_type: Optional[PromptDataType] = None, response_error: PromptResponseError = "none", originator: Originator = "undefined", original_prompt_id: Optional[uuid.UUID] = None, timestamp: Optional[datetime] = None, - scores: Optional[List[Score]] = None, - targeted_harm_categories: Optional[List[str]] = None, + scores: Optional[list[Score]] = None, + targeted_harm_categories: Optional[list[str]] = None, ): """ Initialize a MessagePiece. @@ -111,7 +111,7 @@ def __init__( self.prompt_metadata = prompt_metadata or {} # Handle converter_identifiers: normalize to ComponentIdentifier (handles dict with deprecation warning) - self.converter_identifiers: List[ComponentIdentifier] = ( + self.converter_identifiers: list[ComponentIdentifier] = ( [ComponentIdentifier.normalize(conv_id) for conv_id in converter_identifiers] if converter_identifiers else [] diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 07481b4a97..52669fc033 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -4,7 +4,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import pyrit from pyrit.models import AttackOutcome, AttackResult @@ -60,9 +60,9 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: Union[Dict[str, Any], "ComponentIdentifier"], - attack_results: dict[str, List[AttackResult]], - objective_scorer_identifier: Union[Dict[str, Any], "ComponentIdentifier"], + objective_target_identifier: Union[dict[str, Any], "ComponentIdentifier"], + attack_results: dict[str, list[AttackResult]], + objective_scorer_identifier: Union[dict[str, Any], "ComponentIdentifier"], scenario_run_state: ScenarioRunState = "CREATED", labels: Optional[dict[str, str]] = None, completion_time: Optional[datetime] = None, @@ -115,7 +115,7 @@ def __init__( self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries - def get_strategies_used(self) -> List[str]: + def get_strategies_used(self) -> list[str]: """ Get the list of strategies used in this scenario. @@ -125,7 +125,7 @@ def get_strategies_used(self) -> List[str]: """ return list(self.attack_results.keys()) - def get_objectives(self, *, atomic_attack_name: Optional[str] = None) -> List[str]: + def get_objectives(self, *, atomic_attack_name: Optional[str] = None) -> list[str]: """ Get the list of unique objectives for this scenario. @@ -137,8 +137,8 @@ def get_objectives(self, *, atomic_attack_name: Optional[str] = None) -> List[st List[str]: Deduplicated list of objectives. """ - objectives: List[str] = [] - strategies_to_process: List[List[AttackResult]] + objectives: list[str] = [] + strategies_to_process: list[list[AttackResult]] if not atomic_attack_name: # Include all atomic attacks diff --git a/pyrit/models/score.py b/pyrit/models/score.py index adf64cc043..6e1b2ab79b 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -6,7 +6,7 @@ import uuid from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args if TYPE_CHECKING: from pyrit.identifiers.component_identifier import ComponentIdentifier @@ -29,16 +29,16 @@ class Score: score_type: ScoreType # The harms categories (e.g. ["hate", "violence"]) – can be multiple - score_category: Optional[List[str]] + score_category: Optional[list[str]] # Extra data the scorer provides around the rationale of the score score_rationale: str # Custom metadata a scorer might use. This can vary by scorer. - score_metadata: Optional[Dict[str, Union[str, int, float]]] + score_metadata: Optional[dict[str, Union[str, int, float]]] # The identifier of the scorer class, including relevant information - scorer_class_identifier: "ComponentIdentifier" + scorer_class_identifier: ComponentIdentifier # This is the ID of the MessagePiece that the score is scoring # Note a scorer can generate an additional request. This is NOT that, but @@ -60,9 +60,9 @@ def __init__( score_rationale: str, message_piece_id: str | uuid.UUID, id: Optional[uuid.UUID | str] = None, - score_category: Optional[List[str]] = None, - score_metadata: Optional[Dict[str, Union[str, int, float]]] = None, - scorer_class_identifier: Union["ComponentIdentifier", Dict[str, Any]], + score_category: Optional[list[str]] = None, + score_metadata: Optional[dict[str, Union[str, int, float]]] = None, + scorer_class_identifier: Union[ComponentIdentifier, dict[str, Any]], timestamp: Optional[datetime] = None, objective: Optional[str] = None, ): @@ -155,7 +155,7 @@ def validate(self, scorer_type: str, score_value: str) -> None: except ValueError: raise ValueError(f"Float scale scorers require a numeric score value. Got {score_value}") - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Convert this score to a dictionary. @@ -205,10 +205,10 @@ class UnvalidatedScore: raw_score_value: str score_value_description: str - score_category: Optional[List[str]] + score_category: Optional[list[str]] score_rationale: str - score_metadata: Optional[Dict[str, Union[str, int, float]]] - scorer_class_identifier: "ComponentIdentifier" + score_metadata: Optional[dict[str, Union[str, int, float]]] + scorer_class_identifier: ComponentIdentifier message_piece_id: uuid.UUID | str objective: Optional[str] id: Optional[uuid.UUID | str] = None diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index ffd5e5962b..4096037174 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -13,10 +13,11 @@ import logging import re import uuid +from collections.abc import Iterator, Sequence from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, Iterator, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, TypeVar, Union from jinja2 import BaseLoader, Environment, StrictUndefined, Template, Undefined @@ -115,7 +116,7 @@ class Seed(YamlLoadable): added_by: Optional[str] = None # Arbitrary metadata that can be attached to the prompt - metadata: Optional[Dict[str, Union[str, int]]] = field(default_factory=lambda: {}) + metadata: Optional[dict[str, Union[str, int]]] = field(default_factory=lambda: {}) # Unique identifier for the prompt group prompt_group_id: Optional[uuid.UUID] = None diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index 62112acb4c..3a71c75625 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -9,7 +9,8 @@ from __future__ import annotations -from typing import Any, Dict, Sequence, Union +from collections.abc import Sequence +from typing import Any, Union from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup @@ -30,7 +31,7 @@ class SeedAttackGroup(SeedGroup): def __init__( self, *, - seeds: Sequence[Union[Seed, Dict[str, Any]]], + seeds: Sequence[Union[Seed, dict[str, Any]]], ): """ Initialize a SeedAttackGroup. diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index ec5db2b822..0d7dd91693 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,7 +11,8 @@ from __future__ import annotations -from typing import Any, Dict, Sequence, Union +from collections.abc import Sequence +from typing import Any, Union from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup @@ -31,7 +32,7 @@ class SeedAttackTechniqueGroup(SeedGroup): def __init__( self, *, - seeds: Sequence[Union[Seed, Dict[str, Any]]], + seeds: Sequence[Union[Seed, dict[str, Any]]], ): """ Initialize a SeedAttackTechniqueGroup. diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 3f78c5c6ac..6ae4ff91ce 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -12,8 +12,9 @@ import uuid import warnings from collections import defaultdict +from collections.abc import Sequence from datetime import datetime -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union from pydantic.types import PositiveInt @@ -49,12 +50,12 @@ class SeedDataset(YamlLoadable): added_by: Optional[str] # Now the actual prompts - seeds: Sequence["Seed"] + seeds: Sequence[Seed] def __init__( self, *, - seeds: Optional[Union[Sequence[Dict[str, Any]], Sequence[Seed]]] = None, + seeds: Optional[Union[Sequence[dict[str, Any]], Sequence[Seed]]] = None, data_type: Optional[PromptDataType] = "text", name: Optional[str] = None, dataset_name: Optional[str] = None, @@ -257,7 +258,7 @@ def get_random_values( return random.sample(prompts, min(len(prompts), number)) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> SeedDataset: + def from_dict(cls, data: dict[str, Any]) -> SeedDataset: """ Build a SeedDataset by merging top-level defaults into each item in `seeds`. @@ -362,7 +363,7 @@ def group_seed_prompts_by_prompt_group_id(seeds: Sequence[Seed]) -> Sequence[See """ # Group seeds by `prompt_group_id` - grouped_seeds: Dict[uuid.UUID, list[Seed]] = defaultdict(list) + grouped_seeds: dict[uuid.UUID, list[Seed]] = defaultdict(list) for seed in seeds: if seed.prompt_group_id: grouped_seeds[seed.prompt_group_id].append(seed) diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index e66d582288..cff0ca81ee 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -14,7 +14,8 @@ import uuid import warnings from collections import defaultdict -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from pyrit.common.yaml_loadable import YamlLoadable from pyrit.models.message import Message @@ -40,12 +41,12 @@ class SeedGroup(YamlLoadable): All prompts in the group share the same `prompt_group_id`. """ - seeds: List[Seed] + seeds: list[Seed] def __init__( self, *, - seeds: Sequence[Union[Seed, Dict[str, Any]]], + seeds: Sequence[Union[Seed, dict[str, Any]]], ): """ Initialize a SeedGroup. @@ -279,7 +280,7 @@ def objective(self) -> Optional[SeedObjective]: return self._get_objective() @property - def harm_categories(self) -> List[str]: + def harm_categories(self) -> list[str]: """ Returns a deduplicated list of all harm categories from all seeds. @@ -287,7 +288,7 @@ def harm_categories(self) -> List[str]: List of harm categories with duplicates removed. """ - categories: List[str] = [] + categories: list[str] = [] for seed in self.seeds: if seed.harm_categories: categories.extend(seed.harm_categories) @@ -312,7 +313,7 @@ def has_simulated_conversation(self) -> bool: # ========================================================================= @property - def prepended_conversation(self) -> Optional[List[Message]]: + def prepended_conversation(self) -> Optional[list[Message]]: """ Returns Messages that should be prepended as conversation history. @@ -369,7 +370,7 @@ def next_message(self) -> Optional[Message]: return messages[0] if messages else None @property - def user_messages(self) -> List[Message]: + def user_messages(self) -> list[Message]: """ Returns all prompts as user Messages, one per sequence. @@ -399,7 +400,7 @@ def _get_last_sequence_role(self) -> Optional[str]: return last_sequence_prompts[0].role if last_sequence_prompts else None - def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> List[Message]: + def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> list[Message]: """ Convert a sequence of SeedPrompts to Messages. diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 048098941a..7167ec2e28 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -10,9 +10,10 @@ import logging import os import uuid +from collections.abc import Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Optional, Union from tinytag import TinyTag @@ -114,7 +115,7 @@ def from_yaml_with_required_parameters( template_path: Union[str, Path], required_parameters: list[str], error_message: Optional[str] = None, - ) -> "SeedPrompt": + ) -> SeedPrompt: """ Load a Seed from a YAML file and validate that it contains specific parameters. @@ -141,11 +142,11 @@ def from_yaml_with_required_parameters( @staticmethod def from_messages( - messages: list["Message"], + messages: list[Message], *, starting_sequence: int = 0, prompt_group_id: Optional[uuid.UUID] = None, - ) -> list["SeedPrompt"]: + ) -> list[SeedPrompt]: """ Convert a list of Messages to a list of SeedPrompts. diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index f6ecbbd73c..019e842faa 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -18,7 +18,7 @@ import json import logging from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import pyrit from pyrit.common.path import EXECUTOR_SIMULATED_TARGET_PATH @@ -146,7 +146,7 @@ def _compute_value(self) -> str: return json.dumps(config, sort_keys=True, separators=(",", ":")) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SeedSimulatedConversation": + def from_dict(cls, data: dict[str, Any]) -> SeedSimulatedConversation: """ Create a SeedSimulatedConversation from a dictionary, typically from YAML. @@ -183,7 +183,7 @@ def from_yaml_with_required_parameters( template_path: Union[str, Path], required_parameters: list[str], error_message: Optional[str] = None, - ) -> "SeedSimulatedConversation": + ) -> SeedSimulatedConversation: """ Load a SeedSimulatedConversation from a YAML file and validate required parameters. @@ -209,7 +209,7 @@ def from_yaml_with_required_parameters( return instance - def get_identifier(self) -> Dict[str, Any]: + def get_identifier(self) -> dict[str, Any]: """ Get an identifier dict capturing this configuration for comparison/storage. diff --git a/pyrit/prompt_converter/audio_frequency_converter.py b/pyrit/prompt_converter/audio_frequency_converter.py index 733c846421..867e1e5738 100644 --- a/pyrit/prompt_converter/audio_frequency_converter.py +++ b/pyrit/prompt_converter/audio_frequency_converter.py @@ -97,11 +97,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "audi converted_bytes = bytes_io.getvalue() await audio_serializer.save_data(data=converted_bytes) audio_serializer_file = str(audio_serializer.value) - logger.info( - "Speech synthesized for text [{}], and the audio was saved to [{}]".format( - prompt, audio_serializer_file - ) - ) + logger.info(f"Speech synthesized for text [{prompt}], and the audio was saved to [{audio_serializer_file}]") except Exception as e: logger.error("Failed to convert prompt to audio: %s", str(e)) diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index 0beaa52930..1d890a0b95 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -172,10 +172,10 @@ def recognize_audio(self, audio_bytes: bytes) -> str: # Connect callbacks to the events fired by the speech recognizer speech_recognizer.recognized.connect(lambda evt: self.transcript_cb(evt, transcript=transcribed_text)) - speech_recognizer.recognizing.connect(lambda evt: logger.info("RECOGNIZING: {}".format(evt))) - speech_recognizer.recognized.connect(lambda evt: logger.info("RECOGNIZED: {}".format(evt))) - speech_recognizer.session_started.connect(lambda evt: logger.info("SESSION STARTED: {}".format(evt))) - speech_recognizer.session_stopped.connect(lambda evt: logger.info("SESSION STOPPED: {}".format(evt))) + speech_recognizer.recognizing.connect(lambda evt: logger.info(f"RECOGNIZING: {evt}")) + speech_recognizer.recognized.connect(lambda evt: logger.info(f"RECOGNIZED: {evt}")) + speech_recognizer.session_started.connect(lambda evt: logger.info(f"SESSION STARTED: {evt}")) + speech_recognizer.session_stopped.connect(lambda evt: logger.info(f"SESSION STOPPED: {evt}")) # Stop continuous recognition when stopped or canceled event is fired speech_recognizer.canceled.connect(lambda evt: self.stop_cb(evt, recognizer=speech_recognizer)) speech_recognizer.session_stopped.connect(lambda evt: self.stop_cb(evt, recognizer=speech_recognizer)) @@ -200,7 +200,7 @@ def transcript_cb(self, evt: Any, transcript: list[str]) -> None: evt (speechsdk.SpeechRecognitionEventArgs): Event. transcript (list): List to store transcribed text. """ - logger.info("RECOGNIZED: {}".format(evt.result.text)) + logger.info(f"RECOGNIZED: {evt.result.text}") transcript.append(evt.result.text) def stop_cb(self, evt: Any, recognizer: Any) -> None: @@ -223,13 +223,13 @@ def stop_cb(self, evt: Any, recognizer: Any) -> None: ) raise e - logger.info("CLOSING on {}".format(evt)) + logger.info(f"CLOSING on {evt}") recognizer.stop_continuous_recognition_async() self.done = True if evt.result.reason == speechsdk.ResultReason.Canceled: cancellation_details = evt.result.cancellation_details - logger.info("Speech recognition canceled: {}".format(cancellation_details.reason)) + logger.info(f"Speech recognition canceled: {cancellation_details.reason}") if cancellation_details.reason == speechsdk.CancellationReason.Error: - logger.error("Error details: {}".format(cancellation_details.error_details)) + logger.error(f"Error details: {cancellation_details.error_details}") elif cancellation_details.reason == speechsdk.CancellationReason.EndOfStream: logger.info("End of audio stream detected.") diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index e0ea7dbd3d..e602def6dd 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -167,18 +167,16 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text await audio_serializer.save_data(audio_data) audio_serializer_file = str(audio_serializer.value) logger.info( - "Speech synthesized for text [{}], and the audio was saved to [{}]".format( - prompt, audio_serializer_file - ) + f"Speech synthesized for text [{prompt}], and the audio was saved to [{audio_serializer_file}]" ) elif result.reason == speechsdk.ResultReason.Canceled: cancellation_details = result.cancellation_details - logger.info("Speech synthesis canceled: {}".format(cancellation_details.reason)) + logger.info(f"Speech synthesis canceled: {cancellation_details.reason}") if cancellation_details.reason == speechsdk.CancellationReason.Error: - logger.error("Error details: {}".format(cancellation_details.error_details)) + logger.error(f"Error details: {cancellation_details.error_details}") raise RuntimeError( - "Speech synthesis canceled: {}".format(cancellation_details.reason) - + "Error details: {}".format(cancellation_details.error_details) + f"Speech synthesis canceled: {cancellation_details.reason}" + + f"Error details: {cancellation_details.error_details}" ) except Exception as e: logger.error("Failed to convert prompt to audio: %s", str(e)) diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index a73ac431eb..7f41bbecef 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -6,7 +6,8 @@ import pathlib import re import textwrap -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier diff --git a/pyrit/prompt_converter/colloquial_wordswap_converter.py b/pyrit/prompt_converter/colloquial_wordswap_converter.py index 481013c85c..ce17fe73a6 100644 --- a/pyrit/prompt_converter/colloquial_wordswap_converter.py +++ b/pyrit/prompt_converter/colloquial_wordswap_converter.py @@ -3,7 +3,7 @@ import random import re -from typing import Dict, List, Optional +from typing import Optional from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType @@ -19,7 +19,7 @@ class ColloquialWordswapConverter(PromptConverter): SUPPORTED_OUTPUT_TYPES = ("text",) def __init__( - self, deterministic: bool = False, custom_substitutions: Optional[Dict[str, List[str]]] = None + self, deterministic: bool = False, custom_substitutions: Optional[dict[str, list[str]]] = None ) -> None: """ Initialize the converter with optional deterministic mode and custom substitutions. diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py index 54508d34d8..5192a71c65 100644 --- a/pyrit/prompt_converter/insert_punctuation_converter.py +++ b/pyrit/prompt_converter/insert_punctuation_converter.py @@ -4,7 +4,7 @@ import random import re import string -from typing import List, Optional +from typing import Optional from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType @@ -59,7 +59,7 @@ def _build_identifier(self) -> ComponentIdentifier: } ) - def _is_valid_punctuation(self, punctuation_list: List[str]) -> bool: + def _is_valid_punctuation(self, punctuation_list: list[str]) -> bool: """ Check if all items in the list are valid punctuation characters in string.punctuation. Space, letters, numbers, double punctuations are all invalid. @@ -73,7 +73,7 @@ def _is_valid_punctuation(self, punctuation_list: List[str]) -> bool: return all(str in string.punctuation for str in punctuation_list) async def convert_async( - self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[List[str]] = None + self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[list[str]] = None ) -> ConverterResult: """ Convert the given prompt by inserting punctuation. @@ -105,7 +105,7 @@ async def convert_async( modified_prompt = self._insert_punctuation(prompt, punctuation_list) return ConverterResult(output_text=modified_prompt, output_type="text") - def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str: + def _insert_punctuation(self, prompt: str, punctuation_list: list[str]) -> str: """ Insert punctuation into the prompt. @@ -134,7 +134,7 @@ def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str: return self._insert_within_words(prompt, num_insertions, punctuation_list) def _insert_between_words( - self, words: List[str], word_indices: List[int], num_insertions: int, punctuation_list: List[str] + self, words: list[str], word_indices: list[int], num_insertions: int, punctuation_list: list[str] ) -> str: """ Insert punctuation between words in the prompt. @@ -160,7 +160,7 @@ def _insert_between_words( # Join the words list and return a modified prompt return "".join(words).strip() - def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_list: List[str]) -> str: + def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_list: list[str]) -> str: """ Insert punctuation at any indices in the prompt, can insert into a word. diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index 4cb5900fb7..e36b064258 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -5,7 +5,7 @@ import hashlib from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional from pypdf import PageObject, PdfReader, PdfWriter from reportlab.lib.units import mm @@ -48,7 +48,7 @@ def __init__( column_width: int = 0, row_height: int = 10, existing_pdf: Optional[Path] = None, - injection_items: Optional[List[Dict[str, Any]]] = None, + injection_items: Optional[list[dict[str, Any]]] = None, ) -> None: """ Initialize the converter with the specified parameters. diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index c776297179..8f2c7002ad 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,7 +6,7 @@ import inspect import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union, get_args +from typing import Any, Optional, Union, get_args from pyrit import prompt_converter from pyrit.identifiers import ComponentIdentifier, Identifiable @@ -181,8 +181,8 @@ def _build_identifier(self) -> ComponentIdentifier: def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct and return the converter identifier. @@ -203,7 +203,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this converter. """ - all_params: Dict[str, Any] = { + all_params: dict[str, Any] = { "supported_input_types": self.SUPPORTED_INPUT_TYPES, "supported_output_types": self.SUPPORTED_OUTPUT_TYPES, } diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 185f9d4ef5..74953c2603 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -4,7 +4,7 @@ import logging import random from pathlib import Path -from typing import List, Optional +from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH, DATASETS_PATH @@ -37,7 +37,7 @@ def __init__( *, converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, - languages: Optional[List[str]] = None, + languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, ): """ diff --git a/pyrit/prompt_converter/text_selection_strategy.py b/pyrit/prompt_converter/text_selection_strategy.py index 4afbb3ac6e..18c7f3c70a 100644 --- a/pyrit/prompt_converter/text_selection_strategy.py +++ b/pyrit/prompt_converter/text_selection_strategy.py @@ -4,7 +4,8 @@ import abc import random import re -from typing import List, Optional, Pattern, Union +from re import Pattern +from typing import Optional, Union class TextSelectionStrategy(abc.ABC): @@ -76,7 +77,7 @@ class WordSelectionStrategy(TextSelectionStrategy): """ @abc.abstractmethod - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select word indices to be converted. @@ -400,7 +401,7 @@ class WordIndexSelectionStrategy(WordSelectionStrategy): Selects words based on their indices in the word list. """ - def __init__(self, *, indices: List[int]) -> None: + def __init__(self, *, indices: list[int]) -> None: """ Initialize the word index selection strategy. @@ -409,7 +410,7 @@ def __init__(self, *, indices: List[int]) -> None: """ self._indices = indices - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select words at the specified indices. @@ -439,7 +440,7 @@ class WordKeywordSelectionStrategy(WordSelectionStrategy): Selects words that match specific keywords. """ - def __init__(self, *, keywords: List[str], case_sensitive: bool = True) -> None: + def __init__(self, *, keywords: list[str], case_sensitive: bool = True) -> None: """ Initialize the word keyword selection strategy. @@ -450,7 +451,7 @@ def __init__(self, *, keywords: List[str], case_sensitive: bool = True) -> None: self._keywords = keywords self._case_sensitive = case_sensitive - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select words that match the keywords. @@ -491,7 +492,7 @@ def __init__(self, *, proportion: float, seed: Optional[int] = None) -> None: self._proportion = proportion self._seed = seed - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select a random proportion of words. @@ -525,7 +526,7 @@ def __init__(self, *, pattern: Union[str, Pattern[str]]) -> None: """ self._pattern = re.compile(pattern) if isinstance(pattern, str) else pattern - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select words that match the regex pattern. @@ -569,7 +570,7 @@ def __init__(self, *, start_proportion: float, end_proportion: float) -> None: self._start_proportion = start_proportion self._end_proportion = end_proportion - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select words based on the relative position. @@ -594,7 +595,7 @@ class AllWordsSelectionStrategy(WordSelectionStrategy): Selects all words (default strategy). """ - def select_words(self, *, words: List[str]) -> List[int]: + def select_words(self, *, words: list[str]) -> list[int]: """ Select all words. diff --git a/pyrit/prompt_converter/token_smuggling/base.py b/pyrit/prompt_converter/token_smuggling/base.py index d05e72765a..480be37e79 100644 --- a/pyrit/prompt_converter/token_smuggling/base.py +++ b/pyrit/prompt_converter/token_smuggling/base.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Literal, Tuple +from typing import Literal from pyrit.identifiers import ComponentIdentifier from pyrit.models.literals import PromptDataType @@ -98,7 +98,7 @@ def output_supported(self, output_type: PromptDataType) -> bool: return output_type == "text" @abc.abstractmethod - def encode_message(self, *, message: str) -> Tuple[str, str]: + def encode_message(self, *, message: str) -> tuple[str, str]: """ Encode the given message. diff --git a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py index 4272e9a1ee..15719eba2d 100644 --- a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional, Tuple +from typing import Literal, Optional from pyrit.identifiers import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -58,7 +58,7 @@ def _build_identifier(self) -> ComponentIdentifier: } ) - def encode_message(self, message: str) -> Tuple[str, str]: + def encode_message(self, message: str) -> tuple[str, str]: """ Encode the message using Sneaky Bits mode. diff --git a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py index 7ccfbdec25..cf7ceccdf6 100644 --- a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional, Tuple +from typing import Literal, Optional from pyrit.identifiers import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -67,7 +67,7 @@ def _build_identifier(self) -> ComponentIdentifier: } ) - def encode_message(self, message: str) -> Tuple[str, str]: + def encode_message(self, message: str) -> tuple[str, str]: """ Encode the message using Unicode variation selectors. @@ -144,7 +144,7 @@ def decode_message(self, message: str) -> str: return decoded_text # Extension of Paul Butler's method - def encode_visible_hidden(self, visible: str, hidden: str) -> Tuple[str, str]: + def encode_visible_hidden(self, visible: str, hidden: str) -> tuple[str, str]: """ Combine visible text with hidden text by encoding the hidden text using ``variation_selector_smuggler`` mode. @@ -163,7 +163,7 @@ def encode_visible_hidden(self, visible: str, hidden: str) -> Tuple[str, str]: return summary, combined # Extension of Paul Butler's method - def decode_visible_hidden(self, combined: str) -> Tuple[str, str]: + def decode_visible_hidden(self, combined: str) -> tuple[str, str]: """ Extract the visible text and decodes the hidden text from a combined string. diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index b3df57f108..2580b7fe64 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -5,7 +5,6 @@ import logging from io import BytesIO from pathlib import Path -from typing import Tuple import numpy from PIL import Image @@ -129,7 +128,7 @@ def __init__( self, *, benign_image_path: Path, - size: Tuple[int, int] = (150, 150), + size: tuple[int, int] = (150, 150), steps: int = 1500, learning_rate: float = 0.001, convergence_threshold: float = 1e-6, diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index fdd5cbed82..48fea1d2ba 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Optional from docx import Document @@ -180,7 +180,7 @@ def _prepare_content(self, prompt: str) -> str: if self._prompt_template: logger.debug(f"Preparing Word content with template: {self._prompt_template.value}") try: - dynamic_data: Dict[str, Any] = ast.literal_eval(prompt) + dynamic_data: dict[str, Any] = ast.literal_eval(prompt) if not isinstance(dynamic_data, dict): raise ValueError("Prompt must be a dictionary-compatible object after parsing.") diff --git a/pyrit/prompt_normalizer/prompt_converter_configuration.py b/pyrit/prompt_normalizer/prompt_converter_configuration.py index 9894a60643..cb9ae55425 100644 --- a/pyrit/prompt_normalizer/prompt_converter_configuration.py +++ b/pyrit/prompt_normalizer/prompt_converter_configuration.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import List, Optional +from typing import Optional from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter @@ -19,11 +19,11 @@ class PromptConverterConfiguration: """ converters: list[PromptConverter] - indexes_to_apply: Optional[List[int]] = None - prompt_data_types_to_apply: Optional[List[PromptDataType]] = None + indexes_to_apply: Optional[list[int]] = None + prompt_data_types_to_apply: Optional[list[PromptDataType]] = None @classmethod - def from_converters(cls, *, converters: List[PromptConverter]) -> List["PromptConverterConfiguration"]: + def from_converters(cls, *, converters: list[PromptConverter]) -> list["PromptConverterConfiguration"]: """ Convert a list of converters into a list of PromptConverterConfiguration objects. Each converter gets its own configuration with default settings. diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 1b75dfe8a4..01c8dab711 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -5,7 +5,7 @@ import copy import logging import traceback -from typing import Any, List, Optional +from typing import Any, Optional from uuid import uuid4 from pyrit.exceptions import ( @@ -175,7 +175,7 @@ async def send_prompt_batch_to_target_async( list[Message]: A list of Message objects representing the responses received for each prompt. """ - batch_items: List[List[Any]] = [ + batch_items: list[list[Any]] = [ [request.message for request in requests], [request.request_converter_configurations for request in requests], [request.response_converter_configurations for request in requests], diff --git a/pyrit/prompt_target/batch_helper.py b/pyrit/prompt_target/batch_helper.py index bb5a2b2062..95ec6809fb 100644 --- a/pyrit/prompt_target/batch_helper.py +++ b/pyrit/prompt_target/batch_helper.py @@ -2,12 +2,13 @@ # Licensed under the MIT license. import asyncio -from typing import Any, Callable, Generator, List, Optional, Sequence +from collections.abc import Callable, Generator, Sequence +from typing import Any, Optional from pyrit.prompt_target.common.prompt_target import PromptTarget -def _get_chunks(*args: Sequence[Any], batch_size: int) -> Generator[List[Sequence[Any]], None, None]: +def _get_chunks(*args: Sequence[Any], batch_size: int) -> Generator[list[Sequence[Any]], None, None]: """ Split provided lists into chunks of specified batch size. diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 902d3c10bd..dce3ef0050 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface @@ -24,7 +24,7 @@ class PromptTarget(Identifiable): #: A list of PromptConverters that are supported by the prompt target. #: An empty list implies that the prompt target supports all converters. - supported_converters: List[Any] + supported_converters: list[Any] _identifier: Optional[ComponentIdentifier] = None @@ -96,8 +96,8 @@ def dispose_db_engine(self) -> None: def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the target identifier. @@ -123,7 +123,7 @@ def _create_identifier( # Late import to avoid circular dependency (PromptChatTarget inherits from PromptTarget) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget - all_params: Dict[str, Any] = { + all_params: dict[str, Any] = { "endpoint": self._endpoint, "model_name": model_name, "max_requests_per_minute": self._max_requests_per_minute, diff --git a/pyrit/prompt_target/common/utils.py b/pyrit/prompt_target/common/utils.py index 7054883d0a..ca0a4ca7da 100644 --- a/pyrit/prompt_target/common/utils.py +++ b/pyrit/prompt_target/common/utils.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import asyncio -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from pyrit.exceptions import PyritException diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 50ca68a886..2a85830fc2 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -5,7 +5,8 @@ import json import logging import re -from typing import Any, Callable, Dict, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import Any, Optional import httpx @@ -210,7 +211,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: if cleanup_client: await client.aclose() - def parse_raw_http_request(self, http_request: str) -> tuple[Dict[str, str], RequestBody, str, str, str]: + def parse_raw_http_request(self, http_request: str) -> tuple[dict[str, str], RequestBody, str, str, str]: """ Parse the HTTP request string into a dictionary of headers. @@ -228,7 +229,7 @@ def parse_raw_http_request(self, http_request: str) -> tuple[Dict[str, str], Req Raises: ValueError: If the HTTP request line is invalid. """ - headers_dict: Dict[str, str] = {} + headers_dict: dict[str, str] = {} if self._client: headers_dict = dict(self._client.headers.copy()) if not http_request: @@ -276,7 +277,7 @@ def parse_raw_http_request(self, http_request: str) -> tuple[Dict[str, str], Req def _infer_full_url_from_host( self, path: str, - headers_dict: Dict[str, str], + headers_dict: dict[str, str], ) -> str: # If path is already a full URL, return it as is path = path.lower() diff --git a/pyrit/prompt_target/http_target/http_target_callback_functions.py b/pyrit/prompt_target/http_target/http_target_callback_functions.py index 636b1a0b9d..90cc7f79a3 100644 --- a/pyrit/prompt_target/http_target/http_target_callback_functions.py +++ b/pyrit/prompt_target/http_target/http_target_callback_functions.py @@ -4,7 +4,8 @@ import json import re -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import requests diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index a3e6aad09a..bbf9939280 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -4,7 +4,8 @@ import logging import mimetypes import os -from typing import Any, Callable, Literal, Optional +from collections.abc import Callable +from typing import Any, Literal, Optional import httpx diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f660939b5a..b7d24dbb98 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -4,7 +4,8 @@ import base64 import json import logging -from typing import Any, Dict, MutableSequence, Optional +from collections.abc import MutableSequence +from typing import Any, Optional from pyrit.common import convert_local_image_to_data_url from pyrit.exceptions import ( @@ -676,7 +677,7 @@ def _validate_request(self, *, message: Message) -> None: f"This target only supports text, image_path, and audio_path. Received: {prompt_data_type}." ) - def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[Dict[str, Any]]: + def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: if not json_config.enabled: return None diff --git a/pyrit/prompt_target/openai/openai_error_handling.py b/pyrit/prompt_target/openai/openai_error_handling.py index f70372a150..db275ee934 100644 --- a/pyrit/prompt_target/openai/openai_error_handling.py +++ b/pyrit/prompt_target/openai/openai_error_handling.py @@ -10,7 +10,7 @@ import json import logging -from typing import Optional, Tuple, Union +from typing import Optional, Union logger = logging.getLogger(__name__) @@ -84,7 +84,7 @@ def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: return "content_filter" in lower or "policy_violation" in lower or "moderation_blocked" in lower -def _extract_error_payload(exc: Exception) -> Tuple[Union[dict[str, object], str], bool]: +def _extract_error_payload(exc: Exception) -> tuple[Union[dict[str, object], str], bool]: """ Extract error payload and detect content filter from an OpenAI SDK exception. diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index d0caa44e1c..1c65ed6030 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import base64 import logging -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional import httpx @@ -162,7 +162,7 @@ async def _send_generate_request_async(self, message: Message) -> Message: prompt = message.message_pieces[0].converted_value # Construct request parameters - image_generation_args: Dict[str, Any] = { + image_generation_args: dict[str, Any] = { "model": self._model_name, "prompt": prompt, "size": self.image_size, @@ -212,7 +212,7 @@ async def _send_edit_request_async(self, message: Message) -> Message: image_files.append((image_name, image_bytes, image_type)) # Construct request parameters for image editing - image_edit_args: Dict[str, Any] = { + image_edit_args: dict[str, Any] = { "model": self._model_name, "image": image_files, "prompt": text_prompt, diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 8c4b98d7e7..774eeb7733 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -7,7 +7,7 @@ import re import wave from dataclasses import dataclass, field -from typing import Any, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional from openai import AsyncOpenAI @@ -43,7 +43,7 @@ class RealtimeTargetResult: """ audio_bytes: bytes = field(default_factory=lambda: b"") - transcripts: List[str] = field(default_factory=list) + transcripts: list[str] = field(default_factory=list) def flatten_transcripts(self) -> str: """ @@ -639,7 +639,7 @@ def _extract_error_details(*, response: Any) -> str: return f"[{error_type}] {error_message}" return "Unknown error occurred" - async def send_text_async(self, text: str, conversation_id: str) -> Tuple[str, RealtimeTargetResult]: + async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: """ Send text prompt using OpenAI Realtime API client. @@ -693,7 +693,7 @@ async def send_text_async(self, text: str, conversation_id: str) -> Tuple[str, R output_audio_path = await self.save_audio(audio_bytes=result.audio_bytes, sample_rate=24000) return output_audio_path, result - async def send_audio_async(self, filename: str, conversation_id: str) -> Tuple[str, RealtimeTargetResult]: + async def send_audio_async(self, filename: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: """ Send an audio message using OpenAI Realtime API client. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 1eef3f49b6..d61923bc32 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -3,15 +3,11 @@ import json import logging +from collections.abc import Awaitable, Callable, MutableSequence from enum import Enum from typing import ( Any, - Awaitable, - Callable, - Dict, - List, Literal, - MutableSequence, Optional, cast, ) @@ -74,7 +70,7 @@ class OpenAIResponseTarget(OpenAITarget, PromptChatTarget): def __init__( self, *, - custom_functions: Optional[Dict[str, ToolExecutor]] = None, + custom_functions: Optional[dict[str, ToolExecutor]] = None, max_output_tokens: Optional[int] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, @@ -151,7 +147,7 @@ def __init__( self._extra_body_parameters = extra_body_parameters # Per-instance tool/func registries: - self._custom_functions: Dict[str, ToolExecutor] = custom_functions or {} + self._custom_functions: dict[str, ToolExecutor] = custom_functions or {} self._fail_on_missing_function: bool = fail_on_missing_function # Extract the grammar 'tool' if one is present @@ -202,7 +198,7 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } - async def _construct_input_item_from_piece(self, piece: MessagePiece) -> Dict[str, Any]: + async def _construct_input_item_from_piece(self, piece: MessagePiece) -> dict[str, Any]: """ Convert a single inline piece into a Responses API content item. @@ -226,7 +222,7 @@ async def _construct_input_item_from_piece(self, piece: MessagePiece) -> Dict[st return {"type": "input_image", "image_url": {"url": data_url}} raise ValueError(f"Unsupported piece type for inline content: {piece.converted_value_data_type}") - async def _build_input_for_multi_modal_async(self, conversation: MutableSequence[Message]) -> List[Dict[str, Any]]: + async def _build_input_for_multi_modal_async(self, conversation: MutableSequence[Message]) -> list[dict[str, Any]]: """ Build the Responses API `input` array. @@ -249,7 +245,7 @@ async def _build_input_for_multi_modal_async(self, conversation: MutableSequence if not conversation: raise ValueError("Conversation cannot be empty") - input_items: List[Dict[str, Any]] = [] + input_items: list[dict[str, Any]] = [] for msg_idx, message in enumerate(conversation): pieces = message.message_pieces @@ -268,7 +264,7 @@ async def _build_input_for_multi_modal_async(self, conversation: MutableSequence # All pieces in a Message share the same role role = pieces[0].api_role - content: List[Dict[str, Any]] = [] + content: list[dict[str, Any]] = [] for piece in pieces: dtype = piece.converted_value_data_type @@ -382,7 +378,7 @@ async def _construct_request_body( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _build_reasoning_config(self) -> Optional[Dict[str, Any]]: + def _build_reasoning_config(self) -> Optional[dict[str, Any]]: """ Build the reasoning configuration dict for the Responses API. @@ -392,14 +388,14 @@ def _build_reasoning_config(self) -> Optional[Dict[str, Any]]: if self._reasoning_effort is None and self._reasoning_summary is None: return None - reasoning: Dict[str, Any] = {} + reasoning: dict[str, Any] = {} if self._reasoning_effort is not None: reasoning["effort"] = self._reasoning_effort if self._reasoning_summary is not None: reasoning["summary"] = self._reasoning_summary return reasoning - def _build_text_format(self, json_config: _JsonResponseConfig) -> Optional[Dict[str, Any]]: + def _build_text_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: if not json_config.enabled: return None @@ -480,7 +476,7 @@ async def _construct_message_from_response(self, response: Any, request: Message Message: Constructed message with extracted content from output sections. """ # Extract and parse message pieces from validated output sections - extracted_response_pieces: List[MessagePiece] = [] + extracted_response_pieces: list[MessagePiece] = [] for section in response.output: piece = self._parse_response_output_section( section=section, diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 96bc9e997b..b797b4853e 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -6,7 +6,8 @@ import logging import re from abc import abstractmethod -from typing import Any, Awaitable, Callable, Optional +from collections.abc import Awaitable, Callable +from typing import Any, Optional from urllib.parse import urlparse from openai import ( diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 73fe98bc44..d387e5bd98 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -6,7 +6,7 @@ import time from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Union from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -223,7 +223,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: return [response_entry] - async def _interact_with_copilot_async(self, message: Message) -> Union[str, List[Tuple[str, PromptDataType]]]: + async def _interact_with_copilot_async(self, message: Message) -> Union[str, list[tuple[str, PromptDataType]]]: """ Interact with Microsoft Copilot interface to send multimodal prompts. @@ -247,7 +247,7 @@ async def _interact_with_copilot_async(self, message: Message) -> Union[str, Lis async def _wait_for_response_async( self, selectors: CopilotSelectors - ) -> Union[str, List[Tuple[str, PromptDataType]]]: + ) -> Union[str, list[tuple[str, PromptDataType]]]: """ Wait for Copilot's response and extract the text and/or images. @@ -301,7 +301,7 @@ async def _wait_for_response_async( async def _extract_content_if_ready_async( self, selectors: CopilotSelectors, initial_group_count: int - ) -> Union[str, List[Tuple[str, PromptDataType]], None]: + ) -> Union[str, list[tuple[str, PromptDataType]], None]: """ Extract content if ready, otherwise return None. @@ -343,7 +343,7 @@ async def _extract_content_if_ready_async( logger.debug(f"Error checking content readiness: {e}") return None - async def _extract_text_from_message_groups(self, ai_message_groups: List[Any], text_selector: str) -> List[str]: + async def _extract_text_from_message_groups(self, ai_message_groups: list[Any], text_selector: str) -> list[str]: """ Extract text content from message groups using the provided selector. @@ -367,7 +367,7 @@ async def _extract_text_from_message_groups(self, ai_message_groups: List[Any], return all_text_parts - def _filter_placeholder_text(self, text_parts: List[str]) -> List[str]: + def _filter_placeholder_text(self, text_parts: list[str]) -> list[str]: """ Filter out placeholder/loading text from extracted content. @@ -384,7 +384,7 @@ def _filter_placeholder_text(self, text_parts: List[str]) -> List[str]: ] return [text for text in text_parts if text.lower() not in placeholder_texts] - async def _count_images_in_groups(self, message_groups: List[Any]) -> int: + async def _count_images_in_groups(self, message_groups: list[Any]) -> int: """ Count total images in message groups (both iframes and direct). @@ -425,8 +425,8 @@ async def _wait_minimum_time(self, seconds: int) -> None: logger.debug(f"Minimum wait: {i + 1}/{seconds} seconds") async def _wait_for_images_to_stabilize( - self, selectors: CopilotSelectors, ai_message_groups: List[Any], initial_group_count: int = 0 - ) -> List[Any]: + self, selectors: CopilotSelectors, ai_message_groups: list[Any], initial_group_count: int = 0 + ) -> list[Any]: """ Wait for images to appear and DOM to stabilize. @@ -493,7 +493,7 @@ async def _wait_for_images_to_stabilize( all_groups = await self._page.query_selector_all(selectors.ai_messages_group_selector) return all_groups[initial_group_count:] # type: ignore[no-any-return, unused-ignore] - async def _extract_images_from_iframes(self, ai_message_groups: List[Any]) -> List[Any]: + async def _extract_images_from_iframes(self, ai_message_groups: list[Any]) -> list[Any]: """ Extract images from iframes within message groups. @@ -530,8 +530,8 @@ async def _extract_images_from_iframes(self, ai_message_groups: List[Any]) -> Li return iframe_images async def _extract_images_from_message_groups( - self, selectors: CopilotSelectors, ai_message_groups: List[Any] - ) -> List[Any]: + self, selectors: CopilotSelectors, ai_message_groups: list[Any] + ) -> list[Any]: """ Extract images directly from message groups (fallback when no iframes). @@ -578,7 +578,7 @@ async def _extract_images_from_message_groups( return image_elements - async def _process_image_elements(self, image_elements: List[Any]) -> List[Tuple[str, PromptDataType]]: + async def _process_image_elements(self, image_elements: list[Any]) -> list[tuple[str, PromptDataType]]: """ Process image elements and save them to disk. @@ -588,7 +588,7 @@ async def _process_image_elements(self, image_elements: List[Any]) -> List[Tuple Returns: List of tuples containing (image_path, "image_path") """ - image_pieces: List[Tuple[str, PromptDataType]] = [] + image_pieces: list[tuple[str, PromptDataType]] = [] for i, img_elem in enumerate(image_elements): src = await img_elem.get_attribute(self.ATTR_SRC) @@ -618,8 +618,8 @@ async def _process_image_elements(self, image_elements: List[Any]) -> List[Tuple return image_pieces async def _extract_and_filter_text_async( - self, *, ai_message_groups: List[Any], text_selector: str - ) -> List[Tuple[str, PromptDataType]]: + self, *, ai_message_groups: list[Any], text_selector: str + ) -> list[tuple[str, PromptDataType]]: """ Extract and filter text content from message groups. @@ -635,7 +635,7 @@ async def _extract_and_filter_text_async( filtered_text_parts = self._filter_placeholder_text(all_text_parts) - response_pieces: List[Tuple[str, PromptDataType]] = [] + response_pieces: list[tuple[str, PromptDataType]] = [] if filtered_text_parts: text_content = "\n".join(filtered_text_parts).strip() if text_content: @@ -647,8 +647,8 @@ async def _extract_and_filter_text_async( return response_pieces async def _extract_all_images_async( - self, *, selectors: CopilotSelectors, ai_message_groups: List[Any], initial_group_count: int - ) -> List[Tuple[str, PromptDataType]]: + self, *, selectors: CopilotSelectors, ai_message_groups: list[Any], initial_group_count: int + ) -> list[tuple[str, PromptDataType]]: """ Extract all images from message groups using iframe and direct methods. @@ -677,7 +677,7 @@ async def _extract_all_images_async( # Process and save images return await self._process_image_elements(image_elements) - async def _extract_fallback_text_async(self, *, ai_message_groups: List[Any]) -> str: + async def _extract_fallback_text_async(self, *, ai_message_groups: list[Any]) -> str: """ Extract fallback text content when no other content is found. @@ -697,8 +697,8 @@ async def _extract_fallback_text_async(self, *, ai_message_groups: List[Any]) -> return fallback_result def _assemble_response( - self, *, response_pieces: List[Tuple[str, PromptDataType]] - ) -> Union[str, List[Tuple[str, PromptDataType]]]: + self, *, response_pieces: list[tuple[str, PromptDataType]] + ) -> Union[str, list[tuple[str, PromptDataType]]]: """ Assemble response pieces into appropriate return format. @@ -720,7 +720,7 @@ def _assemble_response( async def _extract_multimodal_content_async( self, selectors: CopilotSelectors, initial_group_count: int = 0 - ) -> Union[str, List[Tuple[str, PromptDataType]]]: + ) -> Union[str, list[tuple[str, PromptDataType]]]: """ Extract multimodal content (text and images) from Copilot response. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 9b27b931da..fe1d3e760f 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -3,7 +3,8 @@ import json import logging -from typing import Any, Callable, Literal, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import Any, Literal, Optional from pyrit.common import default_values, net_utility from pyrit.identifiers import ComponentIdentifier diff --git a/pyrit/prompt_target/rpc_client.py b/pyrit/prompt_target/rpc_client.py index 6533ff5d3c..b7e000b0db 100644 --- a/pyrit/prompt_target/rpc_client.py +++ b/pyrit/prompt_target/rpc_client.py @@ -3,8 +3,9 @@ import socket import time +from collections.abc import Callable from threading import Event, Semaphore, Thread -from typing import Callable, Optional +from typing import Optional import rpyc diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 44d44f9094..51bf59229d 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -8,8 +8,9 @@ and instance registries (which store T instances). """ +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Protocol, Tuple, TypeVar, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, runtime_checkable from pyrit.identifiers.class_name_utils import class_name_to_snake_case @@ -68,16 +69,16 @@ def reset_instance(cls) -> None: """Reset the singleton instance.""" ... - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Get a sorted list of all registered names.""" ... def list_metadata( self, *, - include_filters: Optional[Dict[str, Any]] = None, - exclude_filters: Optional[Dict[str, Any]] = None, - ) -> List[MetadataT]: + include_filters: Optional[dict[str, Any]] = None, + exclude_filters: Optional[dict[str, Any]] = None, + ) -> list[MetadataT]: """ List metadata for all registered items, optionally filtered. @@ -107,7 +108,7 @@ def __iter__(self) -> Iterator[str]: ... -def _get_metadata_value(metadata: Any, key: str) -> Tuple[bool, Any]: +def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: """ Get a value from a metadata object by key. @@ -135,8 +136,8 @@ def _get_metadata_value(metadata: Any, key: str) -> Tuple[bool, Any]: def _matches_filters( metadata: Any, *, - include_filters: Optional[Dict[str, Any]] = None, - exclude_filters: Optional[Dict[str, Any]] = None, + include_filters: Optional[dict[str, Any]] = None, + exclude_filters: Optional[dict[str, Any]] = None, ) -> bool: """ Check if a metadata object matches all provided filters. diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 0571cf216b..ec79136b1c 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -17,7 +17,8 @@ """ from abc import ABC, abstractmethod -from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar +from collections.abc import Callable, Iterator +from typing import Generic, Optional, TypeVar from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.base import RegistryProtocol @@ -48,9 +49,9 @@ class ClassEntry(Generic[T]): def __init__( self, *, - registered_class: Type[T], + registered_class: type[T], factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[Dict[str, object]] = None, + default_kwargs: Optional[dict[str, object]] = None, description: Optional[str] = None, ) -> None: """ @@ -106,7 +107,7 @@ class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]) """ # Class-level singleton instances, keyed by registry class - _instances: Dict[type, "BaseClassRegistry[object, object]"] = {} + _instances: dict[type, "BaseClassRegistry[object, object]"] = {} def __init__(self, *, lazy_discovery: bool = True) -> None: """ @@ -117,8 +118,8 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: If False, discovery runs immediately in constructor. """ # Maps registry names to ClassEntry wrappers - self._class_entries: Dict[str, ClassEntry[T]] = {} - self._metadata_cache: Optional[List[MetadataT]] = None + self._class_entries: dict[str, ClassEntry[T]] = {} + self._metadata_cache: Optional[list[MetadataT]] = None self._discovered = False self._lazy_discovery = lazy_discovery @@ -181,7 +182,7 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT: """ pass - def get_class(self, name: str) -> Type[T]: + def get_class(self, name: str) -> type[T]: """ Get a registered class by name. @@ -217,7 +218,7 @@ def get_entry(self, name: str) -> Optional[ClassEntry[T]]: self._ensure_discovered() return self._class_entries.get(name) - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """ Get a sorted list of all registered names. @@ -233,9 +234,9 @@ def get_names(self) -> List[str]: def list_metadata( self, *, - include_filters: Optional[Dict[str, object]] = None, - exclude_filters: Optional[Dict[str, object]] = None, - ) -> List[MetadataT]: + include_filters: Optional[dict[str, object]] = None, + exclude_filters: Optional[dict[str, object]] = None, + ) -> list[MetadataT]: """ List metadata for all registered classes, optionally filtered. @@ -275,11 +276,11 @@ def list_metadata( def register( self, - cls: Type[T], + cls: type[T], *, name: Optional[str] = None, factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[Dict[str, object]] = None, + default_kwargs: Optional[dict[str, object]] = None, description: Optional[str] = None, ) -> None: """ @@ -325,7 +326,7 @@ def create_instance(self, name: str, **kwargs: object) -> T: raise KeyError(f"'{name}' not found in registry. Available: {available}") return entry.create_instance(**kwargs) - def _get_registry_name(self, cls: Type[T]) -> str: + def _get_registry_name(self, cls: type[T]) -> str: """ Get the registry name for a class. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 147726fcaf..0dd08cb641 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -14,7 +14,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.class_registries.base_class_registry import ( @@ -91,7 +91,7 @@ def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: boo assert self._discovery_path is not None # Track file paths for collision detection and resolution - self._initializer_paths: Dict[str, Path] = {} + self._initializer_paths: dict[str, Path] = {} super().__init__(lazy_discovery=lazy_discovery) @@ -197,7 +197,7 @@ def _register_initializer( except Exception as e: logger.warning(f"Failed to register initializer {initializer_class.__name__}: {e}") - def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> InitializerMetadata: + def _build_metadata(self, name: str, entry: ClassEntry[PyRITInitializer]) -> InitializerMetadata: """ Build metadata for an initializer class. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index d2a5bbb8f7..8d89b8036e 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -154,7 +154,7 @@ def discover_user_scenarios(self) -> None: except Exception as e: logger.debug(f"Failed to discover user scenarios: {e}") - def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioMetadata: + def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMetadata: """ Build metadata for a Scenario class. diff --git a/pyrit/registry/discovery.py b/pyrit/registry/discovery.py index b203e9bb7d..d227744c1c 100644 --- a/pyrit/registry/discovery.py +++ b/pyrit/registry/discovery.py @@ -13,8 +13,9 @@ import inspect import logging import pkgutil +from collections.abc import Callable, Iterator from pathlib import Path -from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar +from typing import Optional, TypeVar logger = logging.getLogger(__name__) @@ -24,9 +25,9 @@ def discover_in_directory( *, directory: Path, - base_class: Type[T], + base_class: type[T], recursive: bool = True, -) -> Iterator[Tuple[str, Path, Type[T]]]: +) -> Iterator[tuple[str, Path, type[T]]]: """ Discover all subclasses of base_class in a directory by loading Python files. @@ -52,7 +53,7 @@ def discover_in_directory( yield from discover_in_directory(directory=item, base_class=base_class, recursive=True) -def _process_file(*, file_path: Path, base_class: Type[T]) -> Iterator[Tuple[str, Path, Type[T]]]: +def _process_file(*, file_path: Path, base_class: type[T]) -> Iterator[tuple[str, Path, type[T]]]: """ Process a Python file and yield subclasses of the base class. @@ -86,11 +87,11 @@ def discover_in_package( *, package_path: Path, package_name: str, - base_class: Type[T], + base_class: type[T], recursive: bool = True, name_builder: Optional[Callable[[str, str], str]] = None, _prefix: str = "", -) -> Iterator[Tuple[str, Type[T]]]: +) -> Iterator[tuple[str, type[T]]]: """ Discover all subclasses using pkgutil.iter_modules on a package. @@ -150,9 +151,9 @@ def discover_in_package( def discover_subclasses_in_loaded_modules( *, - base_class: Type[T], - exclude_module_prefixes: Optional[Tuple[str, ...]] = None, -) -> Iterator[Tuple[str, Type[T]]]: + base_class: type[T], + exclude_module_prefixes: Optional[tuple[str, ...]] = None, +) -> Iterator[tuple[str, type[T]]]: """ Discover subclasses of a base class from already-loaded modules. diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index d0d7b00e7b..55946abc89 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -16,7 +16,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar +from collections.abc import Iterator +from typing import Any, Generic, Optional, TypeVar from pyrit.identifiers import ComponentIdentifier from pyrit.registry.base import RegistryProtocol @@ -42,7 +43,7 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, Metadata """ # Class-level singleton instances, keyed by registry class - _instances: Dict[type, "BaseInstanceRegistry[Any, Any]"] = {} + _instances: dict[type, BaseInstanceRegistry[Any, Any]] = {} @classmethod def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: @@ -71,8 +72,8 @@ def reset_instance(cls) -> None: def __init__(self) -> None: """Initialize the instance registry.""" # Maps registry names to registered items - self._registry_items: Dict[str, T] = {} - self._metadata_cache: Optional[List[MetadataT]] = None + self._registry_items: dict[str, T] = {} + self._metadata_cache: Optional[list[MetadataT]] = None def register( self, @@ -102,7 +103,7 @@ def get(self, name: str) -> Optional[T]: """ return self._registry_items.get(name) - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """ Get a sorted list of all registered names. @@ -111,7 +112,7 @@ def get_names(self) -> List[str]: """ return sorted(self._registry_items.keys()) - def get_all_instances(self) -> Dict[str, T]: + def get_all_instances(self) -> dict[str, T]: """ Get all registered instances as a name -> instance mapping. @@ -123,9 +124,9 @@ def get_all_instances(self) -> Dict[str, T]: def list_metadata( self, *, - include_filters: Optional[Dict[str, object]] = None, - exclude_filters: Optional[Dict[str, object]] = None, - ) -> List[MetadataT]: + include_filters: Optional[dict[str, object]] = None, + exclude_filters: Optional[dict[str, object]] = None, + ) -> list[MetadataT]: """ List metadata for all registered instances, optionally filtered. diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index e530545ba3..d5f71f4805 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -47,7 +47,7 @@ def get_registry_singleton(cls) -> TargetRegistry: def register_instance( self, - target: "PromptTarget", + target: PromptTarget, *, name: Optional[str] = None, ) -> None: @@ -69,7 +69,7 @@ def register_instance( self.register(target, name=name) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: """ Get a registered target instance by name. @@ -83,7 +83,7 @@ def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: """ return self.get(name) - def _build_metadata(self, name: str, instance: "PromptTarget") -> ComponentIdentifier: + def _build_metadata(self, name: str, instance: PromptTarget) -> ComponentIdentifier: """ Build metadata for a target instance. diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 13a775a0c7..908a22250e 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -14,7 +14,7 @@ """ import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional from pyrit.executor.attack import AttackExecutor, AttackStrategy from pyrit.executor.attack.core.attack_executor import AttackExecutorResult @@ -66,10 +66,10 @@ def __init__( *, atomic_attack_name: str, attack: AttackStrategy[Any, Any], - seed_groups: List[SeedAttackGroup], + seed_groups: list[SeedAttackGroup], adversarial_chat: Optional["PromptChatTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - memory_labels: Optional[Dict[str, str]] = None, + memory_labels: Optional[dict[str, str]] = None, **attack_execute_params: Any, ) -> None: """ @@ -119,7 +119,7 @@ def __init__( ) @property - def objectives(self) -> List[str]: + def objectives(self) -> list[str]: """ Get the objectives from the seed groups. @@ -129,7 +129,7 @@ def objectives(self) -> List[str]: return [sg.objective.value for sg in self._seed_groups if sg.objective is not None] @property - def seed_groups(self) -> List[SeedAttackGroup]: + def seed_groups(self) -> list[SeedAttackGroup]: """ Get a copy of the seed groups list for this atomic attack. @@ -138,7 +138,7 @@ def seed_groups(self) -> List[SeedAttackGroup]: """ return list(self._seed_groups) - def filter_seed_groups_by_objectives(self, *, remaining_objectives: List[str]) -> None: + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: """ Filter seed groups to only those with objectives in the remaining list. diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index f585d8d518..70d9256b73 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -11,7 +11,8 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional from pyrit.memory import CentralMemory from pyrit.models import SeedAttackGroup, SeedGroup @@ -44,8 +45,8 @@ class DatasetConfiguration: def __init__( self, *, - seed_groups: Optional[List[SeedGroup]] = None, - dataset_names: Optional[List[str]] = None, + seed_groups: Optional[list[SeedGroup]] = None, + dataset_names: Optional[list[str]] = None, max_dataset_size: Optional[int] = None, scenario_composites: Optional[Sequence[ScenarioCompositeStrategy]] = None, ) -> None: @@ -82,7 +83,7 @@ def __init__( self._dataset_names = list(dataset_names) if dataset_names is not None else None self._scenario_composites = scenario_composites - def get_seed_groups(self) -> Dict[str, List[SeedGroup]]: + def get_seed_groups(self) -> dict[str, list[SeedGroup]]: """ Resolve and return seed groups based on the configuration. @@ -104,7 +105,7 @@ def get_seed_groups(self) -> Dict[str, List[SeedGroup]]: Raises: ValueError: If no seed groups could be resolved from the configuration. """ - result: Dict[str, List[SeedGroup]] = {} + result: dict[str, list[SeedGroup]] = {} if self._seed_groups is not None: # Use explicit seed groups under a special key @@ -129,7 +130,7 @@ def get_seed_groups(self) -> Dict[str, List[SeedGroup]]: return result - def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> List[SeedGroup]: + def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> list[SeedGroup]: """ Load seed groups for a single dataset from memory. @@ -145,7 +146,7 @@ def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> List[SeedGroup] memory = CentralMemory.get_memory_instance() return list(memory.get_seed_groups(dataset_name=dataset_name) or []) - def get_all_seed_groups(self) -> List[SeedGroup]: + def get_all_seed_groups(self) -> list[SeedGroup]: """ Resolve and return all seed groups as a flat list. @@ -161,12 +162,12 @@ def get_all_seed_groups(self) -> List[SeedGroup]: ValueError: If no seed groups could be resolved from the configuration. """ seed_groups_by_dataset = self.get_seed_groups() - all_groups: List[SeedGroup] = [] + all_groups: list[SeedGroup] = [] for groups in seed_groups_by_dataset.values(): all_groups.extend(groups) return all_groups - def get_seed_attack_groups(self) -> Dict[str, List[SeedAttackGroup]]: + def get_seed_attack_groups(self) -> dict[str, list[SeedAttackGroup]]: """ Resolve and return seed groups as SeedAttackGroups, grouped by dataset. @@ -182,12 +183,12 @@ def get_seed_attack_groups(self) -> Dict[str, List[SeedAttackGroup]]: ValueError: If no seed groups could be resolved from the configuration. """ seed_groups_by_dataset = self.get_seed_groups() - result: Dict[str, List[SeedAttackGroup]] = {} + result: dict[str, list[SeedAttackGroup]] = {} for dataset_name, groups in seed_groups_by_dataset.items(): result[dataset_name] = [SeedAttackGroup(seeds=list(sg.seeds)) for sg in groups] return result - def get_all_seed_attack_groups(self) -> List[SeedAttackGroup]: + def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: """ Resolve and return all seed groups as SeedAttackGroups in a flat list. @@ -202,12 +203,12 @@ def get_all_seed_attack_groups(self) -> List[SeedAttackGroup]: ValueError: If no seed groups could be resolved from the configuration. """ attack_groups_by_dataset = self.get_seed_attack_groups() - all_groups: List[SeedAttackGroup] = [] + all_groups: list[SeedAttackGroup] = [] for groups in attack_groups_by_dataset.values(): all_groups.extend(groups) return all_groups - def get_default_dataset_names(self) -> List[str]: + def get_default_dataset_names(self) -> list[str]: """ Get the list of default dataset names for this configuration. @@ -220,7 +221,7 @@ def get_default_dataset_names(self) -> List[str]: return list(self._dataset_names) return [] - def _apply_max_dataset_size(self, seed_groups: List[SeedGroup]) -> List[SeedGroup]: + def _apply_max_dataset_size(self, seed_groups: list[SeedGroup]) -> list[SeedGroup]: """ Apply max_dataset_size sampling to a list of seed groups. @@ -246,7 +247,7 @@ def has_data_source(self) -> bool: """ return self._seed_groups is not None or self._dataset_names is not None - def get_all_seeds(self) -> List[Seed]: + def get_all_seeds(self) -> list[Seed]: """ Load all seed prompts from memory for all configured datasets. @@ -265,7 +266,7 @@ def get_all_seeds(self) -> List[Seed]: raise ValueError("No dataset names configured. Set dataset_names to use get_all_seed_prompts.") memory = CentralMemory.get_memory_instance() - all_seeds: List[Seed] = [] + all_seeds: list[Seed] = [] for dataset_name in self._dataset_names: seeds = memory.get_seeds(dataset_name=dataset_name) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 9139e9b3c3..76bf9e1e6d 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -13,7 +13,8 @@ import textwrap import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union, cast from tqdm.auto import tqdm @@ -54,7 +55,7 @@ def __init__( *, name: str, version: int, - strategy_class: Type[ScenarioStrategy], + strategy_class: type[ScenarioStrategy], objective_scorer: Scorer, include_default_baseline: bool = True, scenario_result_id: Optional[Union[uuid.UUID, str]] = None, @@ -96,7 +97,7 @@ def __init__( # These will be set in initialize_async self._objective_target: Optional[PromptTarget] = None self._objective_target_identifier: Optional[ComponentIdentifier] = None - self._memory_labels: Dict[str, str] = {} + self._memory_labels: dict[str, str] = {} self._max_concurrency: int = 1 self._max_retries: int = 0 @@ -105,18 +106,18 @@ def __init__( self._name = name self._memory = CentralMemory.get_memory_instance() - self._atomic_attacks: List[AtomicAttack] = [] + self._atomic_attacks: list[AtomicAttack] = [] self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None self._result_lock = asyncio.Lock() self._include_baseline = include_default_baseline # Store prepared strategy composites for use in _get_atomic_attacks_async - self._scenario_composites: List[ScenarioCompositeStrategy] = [] + self._scenario_composites: list[ScenarioCompositeStrategy] = [] # Store original objectives for each atomic attack (before any mutations) # Key: atomic_attack_name, Value: tuple of original objectives - self._original_objectives_map: Dict[str, tuple[str, ...]] = {} + self._original_objectives_map: dict[str, tuple[str, ...]] = {} @property def name(self) -> str: @@ -130,7 +131,7 @@ def atomic_attack_count(self) -> int: @classmethod @abstractmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: + def get_strategy_class(cls) -> type[ScenarioStrategy]: """ Get the strategy enum class for this scenario. @@ -182,7 +183,7 @@ async def initialize_async( dataset_config: Optional[DatasetConfiguration] = None, max_concurrency: int = 10, max_retries: int = 0, - memory_labels: Optional[Dict[str, str]] = None, + memory_labels: Optional[dict[str, str]] = None, ) -> None: """ Initialize the scenario by populating self._atomic_attacks and creating the ScenarioResult. @@ -269,7 +270,7 @@ async def initialize_async( self._scenario_result_id = None # Create new scenario result - attack_results: Dict[str, List[AttackResult]] = { + attack_results: dict[str, list[AttackResult]] = { atomic_attack.atomic_attack_name: [] for atomic_attack in self._atomic_attacks } @@ -315,7 +316,7 @@ def _get_baseline(self) -> AtomicAttack: memory_labels=self._memory_labels, ) - def _get_baseline_data(self) -> Tuple[List["SeedAttackGroup"], "AttackScoringConfig", PromptTarget]: + def _get_baseline_data(self) -> tuple[list["SeedAttackGroup"], "AttackScoringConfig", PromptTarget]: """ Get the data needed to create a baseline attack. @@ -398,7 +399,7 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> bool: ) return True - def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Set[str]: + def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> set[str]: """ Get the set of objectives that have already been completed for a specific atomic attack. @@ -411,7 +412,7 @@ def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Se if not self._scenario_result_id: return set() - completed_objectives: Set[str] = set() + completed_objectives: set[str] = set() try: # Retrieve the scenario result from memory @@ -431,7 +432,7 @@ def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Se return completed_objectives - async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: """ Get the list of atomic attacks that still have objectives to complete. @@ -445,7 +446,7 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: # No scenario result yet, return all atomic attacks return self._atomic_attacks - remaining_attacks: List[AtomicAttack] = [] + remaining_attacks: list[AtomicAttack] = [] for atomic_attack in self._atomic_attacks: # Get completed objectives for this atomic attack name @@ -478,7 +479,7 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: return remaining_attacks async def _update_scenario_result_async( - self, *, atomic_attack_name: str, attack_results: List[AttackResult] + self, *, atomic_attack_name: str, attack_results: list[AttackResult] ) -> None: """ Update the scenario result in memory with new attack results (thread-safe). @@ -507,7 +508,7 @@ async def _update_scenario_result_async( ) @abstractmethod - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances in this scenario. diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index eaeea66f62..336d35b9f1 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -11,8 +11,9 @@ It also provides ScenarioCompositeStrategy for representing composed attack strategies. """ +from collections.abc import Sequence from enum import Enum -from typing import List, Sequence, Set, TypeVar +from typing import TypeVar # TypeVar for the enum subclass itself T = TypeVar("T", bound="ScenarioStrategy") @@ -76,7 +77,7 @@ def tags(self) -> set[str]: return self._tags @classmethod - def get_aggregate_tags(cls: type[T]) -> Set[str]: + def get_aggregate_tags(cls: type[T]) -> set[str]: """ Get the set of tags that represent aggregate categories. @@ -93,7 +94,7 @@ def get_aggregate_tags(cls: type[T]) -> Set[str]: return {"all"} @classmethod - def get_strategies_by_tag(cls: type[T], tag: str) -> Set[T]: + def get_strategies_by_tag(cls: type[T], tag: str) -> set[T]: """ Get all attack strategies that have a specific tag. @@ -150,7 +151,7 @@ def get_aggregate_strategies(cls: type[T]) -> list[T]: return [s for s in cls if s.value in aggregate_tags] @classmethod - def normalize_strategies(cls: type[T], strategies: Set[T]) -> Set[T]: + def normalize_strategies(cls: type[T], strategies: set[T]) -> set[T]: """ Normalize a set of attack strategies by expanding aggregate tags. @@ -197,7 +198,7 @@ def prepare_scenario_strategies( strategies: Sequence[T | "ScenarioCompositeStrategy"] | None = None, *, default_aggregate: T | None = None, - ) -> List["ScenarioCompositeStrategy"]: + ) -> list["ScenarioCompositeStrategy"]: """ Prepare and normalize scenario strategies for use in a scenario. @@ -383,7 +384,7 @@ def name(self) -> str: return self._name @property - def strategies(self) -> List[ScenarioStrategy]: + def strategies(self) -> list[ScenarioStrategy]: """Get the list of strategies in this composition.""" return self._strategies @@ -395,7 +396,7 @@ def is_single_strategy(self) -> bool: @staticmethod def extract_single_strategy_values( composites: Sequence["ScenarioCompositeStrategy"], *, strategy_type: type[T] - ) -> Set[str]: + ) -> set[str]: """ Extract strategy values from single-strategy composites. @@ -473,8 +474,8 @@ def get_composite_name(strategies: Sequence[ScenarioStrategy]) -> str: @staticmethod def normalize_compositions( - compositions: List["ScenarioCompositeStrategy"], *, strategy_type: type[T] - ) -> List["ScenarioCompositeStrategy"]: + compositions: list["ScenarioCompositeStrategy"], *, strategy_type: type[T] + ) -> list["ScenarioCompositeStrategy"]: """ Normalize strategy compositions by expanding aggregates while preserving concrete compositions. @@ -514,7 +515,7 @@ def normalize_compositions( raise ValueError("Compositions list cannot be empty") aggregate_tags = strategy_type.get_aggregate_tags() - normalized_compositions: List[ScenarioCompositeStrategy] = [] + normalized_compositions: list[ScenarioCompositeStrategy] = [] for composite in compositions: if not composite.strategies: diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index 6b806cd109..9dfca52103 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import os -from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar +from collections.abc import Sequence +from typing import Any, Optional, TypeVar from pyrit.common import apply_defaults from pyrit.executor.attack import ( @@ -37,7 +38,7 @@ class ContentHarmsDatasetConfiguration(DatasetConfiguration): it filters datasets to only those matching the selected harm strategies. """ - def get_seed_groups(self) -> Dict[str, List[SeedGroup]]: + def get_seed_groups(self) -> dict[str, list[SeedGroup]]: """ Get seed groups filtered by harm strategies from stored scenario_composites. @@ -59,7 +60,7 @@ def get_seed_groups(self) -> Dict[str, List[SeedGroup]]: ) # Filter to matching datasets and map keys to harm names - mapped_result: Dict[str, List[SeedGroup]] = {} + mapped_result: dict[str, list[SeedGroup]] = {} for name, groups in result.items(): matched_harm = next((harm for harm in selected_harms if harm in name), None) if matched_harm: @@ -107,7 +108,7 @@ class ContentHarms(Scenario): VERSION: int = 1 @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: + def get_strategy_class(cls) -> type[ScenarioStrategy]: """ Get the strategy enum class for this scenario. @@ -154,7 +155,7 @@ def __init__( adversarial_chat: Optional[PromptChatTarget] = None, objective_scorer: Optional[TrueFalseScorer] = None, scenario_result_id: Optional[str] = None, - objectives_by_harm: Optional[Dict[str, Sequence[SeedGroup]]] = None, + objectives_by_harm: Optional[dict[str, Sequence[SeedGroup]]] = None, ): """ Initialize the Content Harms Scenario. @@ -206,7 +207,7 @@ def _get_default_scorer(self) -> TrueFalseInverterScorer: ), ) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances for harm strategies. @@ -219,7 +220,7 @@ async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: # Get seed attack groups by harm strategy, already filtered by scenario_composites seed_groups_by_harm = self._dataset_config.get_seed_attack_groups() - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] for strategy, seed_groups in seed_groups_by_harm.items(): atomic_attacks.extend(self._get_strategy_attacks(strategy=strategy, seed_groups=seed_groups)) return atomic_attacks @@ -228,7 +229,7 @@ def _get_strategy_attacks( self, strategy: str, seed_groups: Sequence[SeedAttackGroup], - ) -> List[AtomicAttack]: + ) -> list[AtomicAttack]: """ Create AtomicAttack instances for a given harm strategy. RolePlayAttack, ManyShotJailbreakAttack, PromptSendingAttack, and RedTeamingAttack are run for all harm strategies. diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index ea8ddf41df..e7ecf40650 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, List, Optional +from typing import Any, Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -97,7 +97,7 @@ def __init__( self, *, adversarial_chat: Optional[PromptChatTarget] = None, - objectives: Optional[List[str]] = None, + objectives: Optional[list[str]] = None, objective_scorer: Optional[TrueFalseScorer] = None, include_baseline: bool = True, scenario_result_id: Optional[str] = None, @@ -151,7 +151,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[List[SeedAttackGroup]] = None + self._seed_groups: Optional[list[SeedAttackGroup]] = None def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: """ @@ -201,7 +201,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: temperature=1.2, ) - def _resolve_seed_groups(self) -> List[SeedAttackGroup]: + def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve seed groups from deprecated objectives or dataset configuration. @@ -276,7 +276,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: memory_labels=self._memory_labels, ) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each strategy. @@ -286,7 +286,7 @@ async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: # Resolve seed groups from deprecated objectives or dataset config self._seed_groups = self._resolve_seed_groups() - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] strategies = ScenarioCompositeStrategy.extract_single_strategy_values( composites=self._scenario_composites, strategy_type=CyberStrategy ) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 4e77342c2c..d773c03557 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union from pyrit.common import apply_defaults from pyrit.datasets import TextJailBreak @@ -124,7 +124,7 @@ def __init__( scenario_result_id: Optional[str] = None, num_templates: Optional[int] = None, num_attempts: int = 1, - jailbreak_names: List[str] = [], + jailbreak_names: list[str] = [], ) -> None: """ Initialize the jailbreak scenario. @@ -186,7 +186,7 @@ def __init__( ) # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[List[SeedAttackGroup]] = None + self._seed_groups: Optional[list[SeedAttackGroup]] = None def _get_default_objective_scorer(self) -> TrueFalseScorer: """ @@ -237,7 +237,7 @@ def _get_or_create_adversarial_target(self) -> OpenAIChatTarget: self._adversarial_target = self._create_adversarial_target() return self._adversarial_target - def _resolve_seed_groups(self) -> List[SeedAttackGroup]: + def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve seed groups from dataset configuration. @@ -314,7 +314,7 @@ async def _get_atomic_attack_from_strategy_async( atomic_attack_name=f"jailbreak_{template_name}", attack=attack, seed_groups=self._seed_groups ) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each jailbreak template. @@ -323,7 +323,7 @@ async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: Returns: List[AtomicAttack]: List of atomic attacks to execute, one per jailbreak template. """ - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] # Retrieve seed prompts based on selected strategies self._seed_groups = self._resolve_seed_groups() diff --git a/pyrit/scenario/scenarios/airt/leakage_scenario.py b/pyrit/scenario/scenarios/airt/leakage_scenario.py index 0d70f9d9bb..9424aa40be 100644 --- a/pyrit/scenario/scenarios/airt/leakage_scenario.py +++ b/pyrit/scenario/scenarios/airt/leakage_scenario.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import List, Optional +from typing import Optional from PIL import Image @@ -131,7 +131,7 @@ def __init__( self, *, adversarial_chat: Optional[PromptChatTarget] = None, - objectives: Optional[List[str]] = None, + objectives: Optional[list[str]] = None, objective_scorer: Optional[TrueFalseScorer] = None, include_baseline: bool = True, scenario_result_id: Optional[str] = None, @@ -354,7 +354,7 @@ async def _create_role_play_attack(self) -> RolePlayAttack: attack_scoring_config=self._scorer_config, ) - def _resolve_seed_groups(self) -> List[SeedAttackGroup]: + def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve objectives to SeedAttackGroup format required by AtomicAttack. @@ -363,7 +363,7 @@ def _resolve_seed_groups(self) -> List[SeedAttackGroup]: """ return [SeedAttackGroup(seeds=[SeedObjective(value=obj)]) for obj in self._objectives] - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each strategy. @@ -373,7 +373,7 @@ async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: # Resolve objectives to seed groups format self._seed_groups = self._resolve_seed_groups() - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] strategies = ScenarioCompositeStrategy.extract_single_strategy_values( composites=self._scenario_composites, strategy_type=LeakageStrategy ) diff --git a/pyrit/scenario/scenarios/airt/psychosocial_scenario.py b/pyrit/scenario/scenarios/airt/psychosocial_scenario.py index bf1f8f058e..44c1720f83 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial_scenario.py +++ b/pyrit/scenario/scenarios/airt/psychosocial_scenario.py @@ -5,7 +5,7 @@ import os import pathlib from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Type, TypeVar +from typing import Any, Optional, TypeVar import yaml @@ -70,7 +70,7 @@ class SubharmConfig: class ResolvedSeedData: """Helper dataclass for resolved seed data.""" - seed_groups: List[SeedAttackGroup] + seed_groups: list[SeedAttackGroup] subharm: Optional[str] @@ -152,7 +152,7 @@ class PsychosocialScenario(Scenario): # Set up default subharm configurations # Each subharm (e.g., 'imminent_crisis', 'licensed_therapist') can have unique escalation/scoring # The key is the harm_category_filter value from the strategy - DEFAULT_SUBHARM_CONFIGS: Dict[str, SubharmConfig] = { + DEFAULT_SUBHARM_CONFIGS: dict[str, SubharmConfig] = { "imminent_crisis": SubharmConfig( # set default system prompt for crescendo escalation strategy for crisis scenarios crescendo_system_prompt_path=str( @@ -170,7 +170,7 @@ class PsychosocialScenario(Scenario): } @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: + def get_strategy_class(cls) -> type[ScenarioStrategy]: """ Get the strategy enum class for this scenario. @@ -203,11 +203,11 @@ def default_dataset_config(cls) -> DatasetConfiguration: def __init__( self, *, - objectives: Optional[List[str]] = None, + objectives: Optional[list[str]] = None, adversarial_chat: Optional[PromptChatTarget] = None, objective_scorer: Optional[FloatScaleThresholdScorer] = None, scenario_result_id: Optional[str] = None, - subharm_configs: Optional[Dict[str, SubharmConfig]] = None, + subharm_configs: Optional[dict[str, SubharmConfig]] = None, max_turns: int = 5, ): """ @@ -266,7 +266,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[List[SeedAttackGroup]] = None + self._seed_groups: Optional[list[SeedAttackGroup]] = None def _resolve_seed_groups(self) -> ResolvedSeedData: """ @@ -329,9 +329,9 @@ def _extract_harm_category_filter(self) -> Optional[str]: def _filter_by_harm_category( self, *, - seed_groups: List[SeedAttackGroup], + seed_groups: list[SeedAttackGroup], harm_category: str, - ) -> List[SeedAttackGroup]: + ) -> list[SeedAttackGroup]: """ Filter seed groups by harm category. @@ -351,7 +351,7 @@ def _filter_by_harm_category( filtered_groups.append(SeedAttackGroup(seeds=filtered_seeds)) return filtered_groups - def _expand_strategies_to_base(self) -> Set[str]: + def _expand_strategies_to_base(self) -> set[str]: """ Expand strategy enums to their base strategy tags. @@ -365,7 +365,7 @@ def _expand_strategies_to_base(self) -> Set[str]: strategy_type=PsychosocialStrategy, ) - base_strategies: Set[str] = set() + base_strategies: set[str] = set() for strategy in strategies: try: strategy_enum = PsychosocialStrategy(strategy) @@ -415,7 +415,7 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore rubric_path = pathlib.Path(DATASETS_PATH) / "score" / "likert" / "crisis_management.yaml" # Load rubric from YAML file - with open(rubric_path, "r", encoding="utf-8") as f: + with open(rubric_path, encoding="utf-8") as f: yaml_data = yaml.safe_load(f) # Extract the 'value' field which contains the actual rubric text @@ -444,13 +444,13 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore # threshold=1.0 means only raw score 5 counts as attack success (extremely harmful) return FloatScaleThresholdScorer(scorer=conversation_scorer, threshold=1.0) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: resolved = self._resolve_seed_groups() self._seed_groups = resolved.seed_groups base_strategies = self._expand_strategies_to_base() - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] for strategy in base_strategies: attacks = self._create_attacks_for_strategy( strategy=strategy, @@ -471,8 +471,8 @@ def _create_attacks_for_strategy( *, strategy: str, subharm: Optional[str], - seed_groups: List[SeedAttackGroup], - ) -> List[AtomicAttack]: + seed_groups: list[SeedAttackGroup], + ) -> list[AtomicAttack]: if self._objective_target is None: raise ValueError("objective_target must be set before creating attacks") if not isinstance(self._objective_target, PromptChatTarget): @@ -501,9 +501,9 @@ def _create_single_turn_attacks( self, *, scoring_config: AttackScoringConfig, - seed_groups: List[SeedAttackGroup], - ) -> List[AtomicAttack]: - attacks: List[AtomicAttack] = [] + seed_groups: list[SeedAttackGroup], + ) -> list[AtomicAttack]: + attacks: list[AtomicAttack] = [] tone_converter = ToneConverter(converter_target=self._adversarial_chat, tone="soften") converter_config = AttackConverterConfig( request_converters=PromptConverterConfiguration.from_converters(converters=[tone_converter]) @@ -543,7 +543,7 @@ def _create_multi_turn_attack( *, scoring_config: AttackScoringConfig, subharm: Optional[str], - seed_groups: List[SeedAttackGroup], + seed_groups: list[SeedAttackGroup], ) -> AtomicAttack: subharm_config = self._subharm_configs.get(subharm) if subharm else None crescendo_prompt_path = ( diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 2afc2e4b29..5232f2ce05 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -4,7 +4,7 @@ import logging import os from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Optional from pyrit.common import apply_defaults from pyrit.common.path import ( @@ -130,7 +130,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: def __init__( self, *, - objectives: Optional[List[str]] = None, + objectives: Optional[list[str]] = None, objective_scorer: Optional[TrueFalseScorer] = None, adversarial_chat: Optional[PromptChatTarget] = None, include_baseline: bool = True, @@ -177,7 +177,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[List[SeedAttackGroup]] = None + self._seed_groups: Optional[list[SeedAttackGroup]] = None def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: """ @@ -226,7 +226,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: temperature=1.2, ) - def _resolve_seed_groups(self) -> List[SeedAttackGroup]: + def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve seed groups from deprecated objectives or dataset configuration. @@ -313,7 +313,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: memory_labels=self._memory_labels, ) - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each strategy. @@ -323,7 +323,7 @@ async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: # Resolve seed groups from deprecated objectives or dataset config self._seed_groups = self._resolve_seed_groups() - atomic_attacks: List[AtomicAttack] = [] + atomic_attacks: list[AtomicAttack] = [] strategies = ScenarioCompositeStrategy.extract_single_strategy_values( composites=self._scenario_composites, strategy_type=ScamStrategy ) diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index b85242bcd1..bba9703e7f 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -11,8 +11,9 @@ import logging import os +from collections.abc import Sequence from inspect import signature -from typing import Any, List, Optional, Sequence, Type, TypeVar +from typing import Any, Optional, TypeVar from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -217,7 +218,7 @@ class RedTeamAgent(Scenario): VERSION: int = 1 @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: + def get_strategy_class(cls) -> type[ScenarioStrategy]: """ Get the strategy enum class for this scenario. @@ -246,7 +247,7 @@ def __init__( self, *, adversarial_chat: Optional[PromptChatTarget] = None, - objectives: Optional[List[str]] = None, + objectives: Optional[list[str]] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, include_baseline: bool = True, scenario_result_id: Optional[str] = None, @@ -304,7 +305,7 @@ def __init__( scenario_result_id=scenario_result_id, ) - def _resolve_seed_groups(self) -> List[SeedAttackGroup]: + def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve seed groups from the configuration. This can be removed once objectives is removed. @@ -332,7 +333,7 @@ def _resolve_seed_groups(self) -> List[SeedAttackGroup]: # Use dataset_config (always set by initialize_async) return self._dataset_config.get_all_seed_attack_groups() - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances in this scenario. diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index b4b36e3e8a..ed88d9f4ea 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -3,7 +3,8 @@ import logging -from typing import List, Optional, Sequence +from collections.abc import Sequence +from typing import Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -51,7 +52,7 @@ class EncodingDatasetConfiguration(DatasetConfiguration): with a properly formatted objective for encoding attacks. """ - def get_all_seed_attack_groups(self) -> List[SeedAttackGroup]: + def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: """ Resolve and return all seeds as SeedAttackGroups with encoding objectives. @@ -70,7 +71,7 @@ def get_all_seed_attack_groups(self) -> List[SeedAttackGroup]: if not all_seeds: raise ValueError("No seeds found in the configured datasets.") - seed_attack_groups: List[SeedAttackGroup] = [] + seed_attack_groups: list[SeedAttackGroup] = [] for seed in all_seeds: objective = f"Make the model say an encoded payload: {seed.value}" seed_attack_groups.append( @@ -255,7 +256,7 @@ def _resolve_seed_groups(self) -> list[SeedAttackGroup]: # Flatten all seeds from seed groups return seed_groups - async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances in this scenario. diff --git a/pyrit/score/batch_scorer.py b/pyrit/score/batch_scorer.py index 74c44dcc5c..de571db9d0 100644 --- a/pyrit/score/batch_scorer.py +++ b/pyrit/score/batch_scorer.py @@ -3,8 +3,9 @@ import logging import uuid +from collections.abc import Sequence from datetime import datetime -from typing import Optional, Sequence +from typing import Optional from pyrit.memory import CentralMemory from pyrit.models import ( diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index fe0b52a7f6..c29db3f52c 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -3,7 +3,7 @@ import uuid from abc import ABC, abstractmethod -from typing import Optional, Type, cast +from typing import Optional, cast from uuid import UUID from pyrit.identifiers import ComponentIdentifier @@ -171,7 +171,7 @@ def create_conversation_scorer( >>> isinstance(conversation_scorer, ConversationScorer) # True """ # Determine the base class of the wrapped scorer - scorer_base_class: Optional[Type[Scorer]] = None + scorer_base_class: Optional[type[Scorer]] = None if isinstance(scorer, FloatScaleScorer): scorer_base_class = FloatScaleScorer diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 1d4b774f29..8eea4043bc 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import base64 -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Optional from azure.ai.contentsafety import ContentSafetyClient from azure.ai.contentsafety.models import ( diff --git a/pyrit/score/float_scale/float_scale_score_aggregator.py b/pyrit/score/float_scale/float_scale_score_aggregator.py index dc9b30bebe..f930dd69ae 100644 --- a/pyrit/score/float_scale/float_scale_score_aggregator.py +++ b/pyrit/score/float_scale/float_scale_score_aggregator.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from collections import defaultdict -from typing import Callable, Dict, Iterable, List +from collections.abc import Callable, Iterable from pyrit.models import Score from pyrit.score.score_aggregator_result import ScoreAggregatorResult @@ -11,11 +11,11 @@ format_score_for_rationale, ) -FloatScaleOp = Callable[[List[float]], float] -FloatScaleAggregatorFunc = Callable[[Iterable[Score]], List[ScoreAggregatorResult]] +FloatScaleOp = Callable[[list[float]], float] +FloatScaleAggregatorFunc = Callable[[Iterable[Score]], list[ScoreAggregatorResult]] -def _build_rationale(scores: List[Score], *, aggregate_description: str) -> tuple[str, str]: +def _build_rationale(scores: list[Score], *, aggregate_description: str) -> tuple[str, str]: """ Build description and rationale for aggregated scores. @@ -59,7 +59,7 @@ def _create_aggregator( into a list containing a single ScoreAggregatorResult with a float value in [0, 1]. """ - def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]: + def aggregator(scores: Iterable[Score]) -> list[ScoreAggregatorResult]: # Validate types and normalize input for s in scores: if s.score_type != "float_scale": @@ -181,7 +181,7 @@ def _create_aggregator_by_category( into one or more ScoreAggregatorResult objects. """ - def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]: + def aggregator(scores: Iterable[Score]) -> list[ScoreAggregatorResult]: # Validate types and normalize input for s in scores: if s.score_type != "float_scale": @@ -221,7 +221,7 @@ def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]: # Group scores by category # We need to handle the fact that score_category can be None, [], or a list of categories - category_groups: Dict[str, List[Score]] = defaultdict(list) + category_groups: dict[str, list[Score]] = defaultdict(list) for score in scores_list: categories = getattr(score, "score_category", None) or [] @@ -238,7 +238,7 @@ def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]: category_groups[primary_category].append(score) # Aggregate each category group separately - results: List[ScoreAggregatorResult] = [] + results: list[ScoreAggregatorResult] = [] for category_name, category_scores in sorted(category_groups.items()): float_values = [float(s.get_value()) for s in category_scores] diff --git a/pyrit/score/float_scale/plagiarism_scorer.py b/pyrit/score/float_scale/plagiarism_scorer.py index 444e15efb6..5324be3bef 100644 --- a/pyrit/score/float_scale/plagiarism_scorer.py +++ b/pyrit/score/float_scale/plagiarism_scorer.py @@ -3,7 +3,7 @@ import re from enum import Enum -from typing import List, Optional +from typing import Optional import numpy as np @@ -71,7 +71,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - def _tokenize(self, text: str) -> List[str]: + def _tokenize(self, text: str) -> list[str]: """ Tokenize text using whitespace-based tokenization (case-insensitive). @@ -82,7 +82,7 @@ def _tokenize(self, text: str) -> List[str]: text = re.sub(r"[^\w\s]", "", text) return text.split() - def _lcs_length(self, a: List[str], b: List[str]) -> int: + def _lcs_length(self, a: list[str], b: list[str]) -> int: """ Compute the length of the Longest Common Subsequence at word level. @@ -98,7 +98,7 @@ def _lcs_length(self, a: List[str], b: List[str]) -> int: dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) return int(dp[len(a)][len(b)]) - def _levenshtein_distance(self, a: List[str], b: List[str]) -> int: + def _levenshtein_distance(self, a: list[str], b: list[str]) -> int: """ Compute Levenshtein edit distance at word level. @@ -116,7 +116,7 @@ def _levenshtein_distance(self, a: List[str], b: List[str]) -> int: dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) return int(dp[len(a)][len(b)]) - def _ngram_set(self, tokens: List[str], n: int) -> set[tuple[str, ...]]: + def _ngram_set(self, tokens: list[str], n: int) -> set[tuple[str, ...]]: """ Generate a set of n-grams from token list. diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index e9d04b86df..25c126ad92 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional +from typing import Optional import yaml @@ -31,7 +31,7 @@ class LikertScaleEvalFiles: The harm definition path is derived as "{harm_category}.yaml". """ - human_labeled_datasets_files: List[str] + human_labeled_datasets_files: list[str] result_file: str harm_category: Optional[str] = None @@ -233,7 +233,7 @@ def _set_likert_scale_system_prompt(self, likert_scale_path: Path) -> None: likert_scale=likert_scale_str, category=self._score_category ) - def _likert_scale_description_to_string(self, descriptions: list[Dict[str, str]]) -> str: + def _likert_scale_description_to_string(self, descriptions: list[dict[str, str]]) -> str: """ Convert the Likert scales to a string representation to be put in a system prompt. diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index 78f5037729..14ce664017 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List, Optional +from typing import Optional from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score @@ -147,7 +147,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op piece_id = message_piece.id if message_piece.id is not None else message_piece.original_prompt_id # Call the aggregator - all aggregators now return List[ScoreAggregatorResult] - aggregator_results: List[ScoreAggregatorResult] = self._score_aggregator(all_scores) + aggregator_results: list[ScoreAggregatorResult] = self._score_aggregator(all_scores) # Build rationale prefix rationale_prefix = f"Video scored by analyzing {len(frame_scores)} frames" @@ -155,7 +155,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op rationale_prefix += " and audio transcript" # Create Score objects from aggregator results - aggregate_scores: List[Score] = [] + aggregate_scores: list[Score] = [] for result in aggregator_results: aggregate_score = Score( score_value=str(result.value), diff --git a/pyrit/score/score_aggregator_result.py b/pyrit/score/score_aggregator_result.py index e66f46fc03..de5b8dc212 100644 --- a/pyrit/score/score_aggregator_result.py +++ b/pyrit/score/score_aggregator_result.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Dict, List, Union +from typing import Union @dataclass(frozen=True, slots=True) @@ -22,5 +22,5 @@ class ScoreAggregatorResult: value: Union[bool, float] description: str rationale: str - category: List[str] - metadata: Dict[str, Union[str, int, float]] + category: list[str] + metadata: dict[str, Union[str, int, float]] diff --git a/pyrit/score/score_utils.py b/pyrit/score/score_utils.py index 7a8258cac0..fe5a3cb066 100644 --- a/pyrit/score/score_utils.py +++ b/pyrit/score/score_utils.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Dict, List, Optional, Union +from typing import Optional, Union from pyrit.common.utils import combine_dict from pyrit.models import Score @@ -11,7 +11,7 @@ ORIGINAL_FLOAT_VALUE_KEY = "original_float_value" -def combine_metadata_and_categories(scores: List[Score]) -> tuple[Dict[str, Union[str, int, float]], List[str]]: +def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, Union[str, int, float]], list[str]]: """ Combine metadata and categories from multiple scores with deduplication. @@ -21,7 +21,7 @@ def combine_metadata_and_categories(scores: List[Score]) -> tuple[Dict[str, Unio Returns: Tuple of (metadata dict, sorted category list with empty strings filtered). """ - metadata: Dict[str, Union[str, int, float]] = {} + metadata: dict[str, Union[str, int, float]] = {} category_set: set[str] = set() for s in scores: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index f2ffa8e60c..fa9069e846 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -9,13 +9,11 @@ import logging import uuid from abc import abstractmethod +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Sequence, Union, cast, ) @@ -58,7 +56,7 @@ class Scorer(Identifiable, abc.ABC): # Evaluation configuration - maps input dataset files to a result file # Specifies glob patterns for datasets and a result file name - evaluation_file_mapping: Optional["ScorerEvalDatasetFiles"] = None + evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None _identifier: Optional[ComponentIdentifier] = None @@ -98,8 +96,8 @@ def _memory(self) -> MemoryInterface: def _create_identifier( self, *, - params: Optional[Dict[str, Any]] = None, - children: Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]] = None, + params: Optional[dict[str, Any]] = None, + children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, ) -> ComponentIdentifier: """ Construct the scorer identifier. @@ -120,7 +118,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this scorer. """ - all_params: Dict[str, Any] = { + all_params: dict[str, Any] = { "scorer_type": self.scorer_type, } if params: @@ -250,12 +248,12 @@ def validate_return_scores(self, scores: list[Score]) -> None: async def evaluate_async( self, - file_mapping: Optional["ScorerEvalDatasetFiles"] = None, + file_mapping: Optional[ScorerEvalDatasetFiles] = None, *, num_scorer_trials: int = 3, - update_registry_behavior: "RegistryUpdateBehavior" = None, + update_registry_behavior: RegistryUpdateBehavior = None, max_concurrency: int = 10, - ) -> Optional["ScorerMetrics"]: + ) -> Optional[ScorerMetrics]: """ Evaluate this scorer against human-labeled datasets. @@ -304,7 +302,7 @@ async def evaluate_async( ) @abstractmethod - def get_scorer_metrics(self) -> Optional["ScorerMetrics"]: + def get_scorer_metrics(self) -> Optional[ScorerMetrics]: """ Get evaluation metrics for this scorer from the configured evaluation result file. @@ -612,7 +610,7 @@ async def _score_value_with_llm( # Normalize metadata to a dictionary with string keys and string/int/float values raw_md = parsed_response.get(metadata_output_key) - normalized_md: Optional[Dict[str, Union[str, int, float]]] + normalized_md: Optional[dict[str, Union[str, int, float]]] if raw_md is None: normalized_md = None elif isinstance(raw_md, dict): @@ -680,11 +678,11 @@ async def score_response_async( *, response: Message, objective_scorer: Optional[Scorer] = None, - auxiliary_scorers: Optional[List[Scorer]] = None, + auxiliary_scorers: Optional[list[Scorer]] = None, role_filter: ChatMessageRole = "assistant", objective: Optional[str] = None, skip_on_error_result: bool = True, - ) -> Dict[str, List[Score]]: + ) -> dict[str, list[Score]]: """ Score a response using an objective scorer and optional auxiliary scorers. @@ -704,7 +702,7 @@ async def score_response_async( Raises: ValueError: If response is not provided. """ - result: Dict[str, List[Score]] = {"auxiliary_scores": [], "objective_scores": []} + result: dict[str, list[Score]] = {"auxiliary_scores": [], "objective_scores": []} if not response: raise ValueError("Response must be provided for scoring.") @@ -755,11 +753,11 @@ async def score_response_async( async def score_response_multiple_scorers_async( *, response: Message, - scorers: List[Scorer], + scorers: list[Scorer], role_filter: ChatMessageRole = "assistant", objective: Optional[str] = None, skip_on_error_result: bool = True, - ) -> List[Score]: + ) -> list[Score]: """ Score a response using multiple scorers in parallel. diff --git a/pyrit/score/scorer_evaluation/human_labeled_dataset.py b/pyrit/score/scorer_evaluation/human_labeled_dataset.py index 592e9b562f..b55a3b765e 100644 --- a/pyrit/score/scorer_evaluation/human_labeled_dataset.py +++ b/pyrit/score/scorer_evaluation/human_labeled_dataset.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast import pandas as pd @@ -42,8 +42,8 @@ class HumanLabeledEntry: a different person's score for the same response/conversation. """ - conversation: List[Message] - human_scores: List[Any] + conversation: list[Message] + human_scores: list[Any] def __post_init__(self) -> None: """ @@ -67,7 +67,7 @@ class HarmHumanLabeledEntry(HumanLabeledEntry): string that represents the type of harm (e.g., "hate_speech", "misinformation", etc.). """ - human_scores: List[float] + human_scores: list[float] # For now, this is a string, but may be enum or Literal in the future. harm_category: str @@ -96,7 +96,7 @@ class ObjectiveHumanLabeledEntry(HumanLabeledEntry): string that represents the objective (e.g., "how to make a Molotov cocktail?). """ - human_scores: List[bool] + human_scores: list[bool] objective: str def __post_init__(self) -> None: @@ -123,7 +123,7 @@ def __init__( self, *, name: str, - entries: List[HumanLabeledEntry], + entries: list[HumanLabeledEntry], metrics_type: MetricsType, version: str, harm_definition: Optional[str] = None, @@ -156,7 +156,7 @@ def __init__( self.version = version self.harm_definition = harm_definition self.harm_definition_version = harm_definition_version - self._harm_definition_obj: Optional["HarmDefinition"] = None + self._harm_definition_obj: Optional[HarmDefinition] = None def get_harm_definition(self) -> Optional["HarmDefinition"]: """ @@ -233,7 +233,7 @@ def from_csv( parsed_version = None parsed_harm_definition = None parsed_harm_definition_version = None - with open(csv_path, "r", encoding="utf-8") as f: + with open(csv_path, encoding="utf-8") as f: first_line = f.readline().strip() if first_line.startswith("#"): # Parse key=value pairs from the comment line @@ -294,7 +294,7 @@ def from_csv( else: data_types = ["text"] * len(eval_df[STANDARD_ASSISTANT_RESPONSE_COL]) - entries: List[HumanLabeledEntry] = [] + entries: list[HumanLabeledEntry] = [] for response_to_score, human_scores, objective_or_harm, data_type in zip( responses_to_score, all_human_scores, objectives_or_harms, data_types ): @@ -437,13 +437,13 @@ def _validate_csv_columns(cls, *, eval_df: pd.DataFrame, metrics_type: MetricsTy raise ValueError(f"Human score column '{col}' contains NaN values.") @staticmethod - def _construct_harm_entry(*, messages: List[Message], harm: str, human_scores: List[Any]) -> HarmHumanLabeledEntry: + def _construct_harm_entry(*, messages: list[Message], harm: str, human_scores: list[Any]) -> HarmHumanLabeledEntry: float_scores = [float(score) for score in human_scores] return HarmHumanLabeledEntry(messages, float_scores, harm) @staticmethod def _construct_objective_entry( - *, messages: List[Message], objective: str, human_scores: List[Any] + *, messages: list[Message], objective: str, human_scores: list[Any] ) -> "ObjectiveHumanLabeledEntry": # Convert scores to int before casting to bool in case the values (0, 1) are parsed as strings bool_scores = [bool(int(score)) for score in human_scores] diff --git a/pyrit/score/scorer_evaluation/krippendorff.py b/pyrit/score/scorer_evaluation/krippendorff.py index dde1ec0381..5e9fb5c64e 100644 --- a/pyrit/score/scorer_evaluation/krippendorff.py +++ b/pyrit/score/scorer_evaluation/krippendorff.py @@ -17,10 +17,10 @@ def _validate_and_prepare_data( - reliability_data: "np.ndarray", # type: ignore[type-arg, unused-ignore] + reliability_data: np.ndarray, # type: ignore[type-arg, unused-ignore] level_of_measurement: str, missing: float | None, -) -> tuple["np.ndarray", "np.ndarray", "np.ndarray"]: # type: ignore[type-arg, unused-ignore] +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: # type: ignore[type-arg, unused-ignore] """ Validate inputs and prepare data for reliability calculation. @@ -63,10 +63,10 @@ def _validate_and_prepare_data( def _build_value_counts_matrix( - data: "np.ndarray", # type: ignore[type-arg, unused-ignore] - valid_mask: "np.ndarray", # type: ignore[type-arg, unused-ignore] - categories: "np.ndarray", # type: ignore[type-arg, unused-ignore] -) -> "np.ndarray": # type: ignore[type-arg, unused-ignore] + data: np.ndarray, # type: ignore[type-arg, unused-ignore] + valid_mask: np.ndarray, # type: ignore[type-arg, unused-ignore] + categories: np.ndarray, # type: ignore[type-arg, unused-ignore] +) -> np.ndarray: # type: ignore[type-arg, unused-ignore] """ Build matrix counting how many raters assigned each category to each item. @@ -95,8 +95,8 @@ def _build_value_counts_matrix( def _build_coincidence_matrix( - value_counts: "np.ndarray", # type: ignore[type-arg, unused-ignore] -) -> "np.ndarray": # type: ignore[type-arg, unused-ignore] + value_counts: np.ndarray, # type: ignore[type-arg, unused-ignore] +) -> np.ndarray: # type: ignore[type-arg, unused-ignore] """ Build coincidence matrix from value counts. @@ -132,8 +132,8 @@ def _build_coincidence_matrix( def _build_expected_matrix( - coincidence_matrix: "np.ndarray", # type: ignore[type-arg, unused-ignore] -) -> tuple["np.ndarray", "np.ndarray", float]: # type: ignore[type-arg, unused-ignore] + coincidence_matrix: np.ndarray, # type: ignore[type-arg, unused-ignore] +) -> tuple[np.ndarray, np.ndarray, float]: # type: ignore[type-arg, unused-ignore] """ Build expected coincidence matrix from observed coincidences. @@ -160,8 +160,8 @@ def _build_expected_matrix( def _build_ordinal_distance_matrix( num_categories: int, - n_v: "np.ndarray", # type: ignore[type-arg, unused-ignore] -) -> "np.ndarray": # type: ignore[type-arg, unused-ignore] + n_v: np.ndarray, # type: ignore[type-arg, unused-ignore] +) -> np.ndarray: # type: ignore[type-arg, unused-ignore] """ Build ordinal distance matrix using category marginals. @@ -216,7 +216,7 @@ def _compute_alpha_from_disagreements( def krippendorff_alpha( - reliability_data: "np.ndarray", # type: ignore[type-arg, unused-ignore] # shape: (num_raters_or_trials, num_items); dtype float + reliability_data: np.ndarray, # type: ignore[type-arg, unused-ignore] # shape: (num_raters_or_trials, num_items); dtype float level_of_measurement: str = "ordinal", missing: float | None = np.nan, ) -> float: diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index bee803e5bd..be5c134fd5 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -8,7 +8,7 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, cast +from typing import Optional, cast import numpy as np from scipy.stats import ttest_1samp @@ -65,7 +65,7 @@ class ScorerEvalDatasetFiles: Required for harm evaluations, ignored for objective evaluations. Defaults to None. """ - human_labeled_datasets_files: List[str] + human_labeled_datasets_files: list[str] result_file: str harm_category: Optional[str] = None @@ -89,7 +89,7 @@ def __init__(self, scorer: Scorer): self.scorer = scorer @classmethod - def from_scorer(cls, scorer: Scorer, metrics_type: Optional[MetricsType] = None) -> "ScorerEvaluator": + def from_scorer(cls, scorer: Scorer, metrics_type: Optional[MetricsType] = None) -> ScorerEvaluator: """ Create a ScorerEvaluator based on the type of scoring. @@ -154,7 +154,7 @@ async def run_evaluation_async( ) # Collect all matching files - csv_files: List[Path] = [] + csv_files: list[Path] = [] for pattern in dataset_files.human_labeled_datasets_files: matched = list(SCORER_EVALS_PATH.glob(pattern)) csv_files.extend(matched) @@ -249,7 +249,7 @@ def _should_skip_evaluation( num_scorer_trials: int, harm_category: Optional[str] = None, result_file_path: Path, - ) -> Tuple[bool, Optional[ScorerMetrics]]: + ) -> tuple[bool, Optional[ScorerMetrics]]: """ Determine whether to skip evaluation based on existing registry entries. @@ -420,7 +420,7 @@ async def _run_evaluation_async( def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> Tuple[List[Message], List[List[float]], Optional[List[str]]]: + ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: """ Validate the dataset and extract data for evaluation. @@ -500,7 +500,7 @@ class HarmScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> Tuple[List[Message], List[List[float]], Optional[List[str]]]: + ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: """ Validate harm dataset and extract evaluation data. @@ -519,8 +519,8 @@ def _validate_and_extract_data( labeled_dataset.validate() - assistant_responses: List[Message] = [] - human_scores_list: List[List[float]] = [] + assistant_responses: list[Message] = [] + human_scores_list: list[list[float]] = [] for entry in labeled_dataset.entries: harm_entry = cast(HarmHumanLabeledEntry, entry) @@ -553,7 +553,7 @@ def _compute_metrics( diff[np.abs(diff) < 1e-10] = 0.0 abs_error = np.abs(diff) - t_statistic, p_value = cast(Tuple[float, float], ttest_1samp(diff, 0)) + t_statistic, p_value = cast(tuple[float, float], ttest_1samp(diff, 0)) num_responses = all_human_scores.shape[1] num_human_raters = all_human_scores.shape[0] @@ -601,7 +601,7 @@ class ObjectiveScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> Tuple[List[Message], List[List[float]], Optional[List[str]]]: + ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: """ Validate objective dataset and extract evaluation data. @@ -619,9 +619,9 @@ def _validate_and_extract_data( labeled_dataset.validate() - assistant_responses: List[Message] = [] - human_scores_list: List[List[float]] = [] - objectives: List[str] = [] + assistant_responses: list[Message] = [] + human_scores_list: list[list[float]] = [] + objectives: list[str] = [] for entry in labeled_dataset.entries: objective_entry = cast(ObjectiveHumanLabeledEntry, entry) diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index c482eb1fb7..560e53654e 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -6,7 +6,7 @@ import json from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import numpy as np @@ -55,7 +55,7 @@ def to_json(self) -> str: return json.dumps(asdict(self)) @classmethod - def from_json(cls: Type[T], file_path: Union[str, Path]) -> T: + def from_json(cls: type[T], file_path: Union[str, Path]) -> T: """ Load the metrics from a JSON file. @@ -69,7 +69,7 @@ def from_json(cls: Type[T], file_path: Union[str, Path]) -> T: FileNotFoundError: If the specified file does not exist. """ file_path = verify_and_resolve_path(file_path) - with open(file_path, "r") as f: + with open(file_path) as f: data = json.load(f) # Extract metrics from nested structure (always under "metrics" key in evaluation result files) @@ -121,9 +121,9 @@ class HarmScorerMetrics(ScorerMetrics): harm_definition_version: Optional[str] = field(default=None, kw_only=True) krippendorff_alpha_humans: Optional[float] = None krippendorff_alpha_model: Optional[float] = None - _harm_definition_obj: Optional["HarmDefinition"] = field(default=None, init=False, repr=False) + _harm_definition_obj: Optional[HarmDefinition] = field(default=None, init=False, repr=False) - def get_harm_definition(self) -> Optional["HarmDefinition"]: + def get_harm_definition(self) -> Optional[HarmDefinition]: """ Load and return the HarmDefinition object for this metrics instance. @@ -192,7 +192,7 @@ class ScorerMetricsWithIdentity(Generic[M]): metrics (M): The evaluation metrics (ObjectiveScorerMetrics or HarmScorerMetrics). """ - scorer_identifier: "ComponentIdentifier" + scorer_identifier: ComponentIdentifier metrics: M def __repr__(self) -> str: diff --git a/pyrit/score/scorer_evaluation/scorer_metrics_io.py b/pyrit/score/scorer_evaluation/scorer_metrics_io.py index 0ecd7804b0..07c9f83bd9 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics_io.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics_io.py @@ -11,7 +11,7 @@ import threading from dataclasses import asdict from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, Optional, TypeVar from pyrit.common.path import ( SCORER_EVALS_PATH, @@ -28,7 +28,7 @@ # Thread locks for writing (module-level, persists for application lifetime) # Locks are created per file path to ensure thread-safe writes -_file_write_locks: Dict[str, threading.Lock] = {} +_file_write_locks: dict[str, threading.Lock] = {} M = TypeVar("M", bound=ScorerMetrics) @@ -43,7 +43,7 @@ def _build_eval_dict( identifier: ComponentIdentifier, *, param_allowlist: Optional[frozenset[str]] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Build a dictionary for eval hashing. @@ -62,7 +62,7 @@ def _build_eval_dict( Returns: Dict[str, Any]: The filtered dictionary suitable for hashing. """ - eval_dict: Dict[str, Any] = { + eval_dict: dict[str, Any] = { ComponentIdentifier.KEY_CLASS_NAME: identifier.class_name, ComponentIdentifier.KEY_CLASS_MODULE: identifier.class_module, } @@ -72,7 +72,7 @@ def _build_eval_dict( eval_dict[key] = value if identifier.children: - eval_children: Dict[str, Any] = {} + eval_children: dict[str, Any] = {} for name in sorted(identifier.children): child_list = identifier.get_child_list(name) if name in _TARGET_CHILD_KEYS: @@ -106,7 +106,7 @@ def compute_eval_hash(identifier: ComponentIdentifier) -> str: return config_hash(_build_eval_dict(identifier)) -def _metrics_to_registry_dict(metrics: ScorerMetrics) -> Dict[str, Any]: +def _metrics_to_registry_dict(metrics: ScorerMetrics) -> dict[str, Any]: """ Convert metrics to a dictionary suitable for registry storage. @@ -127,7 +127,7 @@ def _metrics_to_registry_dict(metrics: ScorerMetrics) -> Dict[str, Any]: def get_all_objective_metrics( file_path: Optional[Path] = None, -) -> List[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: +) -> list[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: """ Load all objective scorer metrics with full scorer identity for comparison. @@ -153,7 +153,7 @@ def get_all_objective_metrics( def get_all_harm_metrics( harm_category: str, -) -> List[ScorerMetricsWithIdentity[HarmScorerMetrics]]: +) -> list[ScorerMetricsWithIdentity[HarmScorerMetrics]]: """ Load all harm scorer metrics for a specific harm category. @@ -176,8 +176,8 @@ def get_all_harm_metrics( def _load_metrics_from_file( *, file_path: Path, - metrics_class: Type[M], -) -> List[ScorerMetricsWithIdentity[M]]: + metrics_class: type[M], +) -> list[ScorerMetricsWithIdentity[M]]: """ Load scorer metrics from a JSONL file with the specified metrics class. @@ -190,7 +190,7 @@ def _load_metrics_from_file( Returns: List[ScorerMetricsWithIdentity[M]]: List of metrics with scorer identity. """ - results: List[ScorerMetricsWithIdentity[M]] = [] + results: list[ScorerMetricsWithIdentity[M]] = [] entries = _load_jsonl(file_path) for entry in entries: @@ -267,7 +267,7 @@ def _find_metrics_by_hash( *, file_path: Path, hash: str, - metrics_class: Type[M], + metrics_class: type[M], ) -> Optional[M]: """ Find scorer metrics by configuration hash in a specific file. @@ -337,7 +337,7 @@ def add_evaluation_results( logger.info(f"Added metrics for {scorer_identifier.class_name} to {file_path.name}") -def _load_jsonl(file_path: Path) -> List[Dict[str, Any]]: +def _load_jsonl(file_path: Path) -> list[dict[str, Any]]: """ Load entries from a JSONL file. @@ -353,7 +353,7 @@ def _load_jsonl(file_path: Path) -> List[Dict[str, Any]]: entries = [] try: - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: @@ -367,7 +367,7 @@ def _load_jsonl(file_path: Path) -> List[Dict[str, Any]]: return entries -def _append_jsonl_entry(file_path: Path, lock: threading.Lock, entry: Dict[str, Any]) -> None: +def _append_jsonl_entry(file_path: Path, lock: threading.Lock, entry: dict[str, Any]) -> None: """ Append an entry to a JSONL file with thread safety. diff --git a/pyrit/score/scorer_prompt_validator.py b/pyrit/score/scorer_prompt_validator.py index a9043afba3..513b62c7e4 100644 --- a/pyrit/score/scorer_prompt_validator.py +++ b/pyrit/score/scorer_prompt_validator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional, Sequence, get_args +from collections.abc import Sequence +from typing import Optional, get_args from pyrit.models import ChatMessageRole, Message, MessagePiece, PromptDataType diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index acc6a524ae..7102ba3af6 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -3,7 +3,7 @@ import enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union import yaml @@ -95,7 +95,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - def _content_classifier_to_string(self, categories: list[Dict[str, str]]) -> str: + def _content_classifier_to_string(self, categories: list[dict[str, str]]) -> str: """ Convert the content classifier categories to a string representation to be put in a system prompt. diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index fb657d7157..4e7ee59412 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. import enum +from collections.abc import Iterator from pathlib import Path -from typing import Any, Iterator, Optional, Union +from typing import Any, Optional, Union import yaml diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index b256bb613f..c66c24d437 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import asyncio -from typing import List, Optional +from typing import Optional from pyrit.identifiers import ComponentIdentifier from pyrit.models import ChatMessageRole, Message, MessagePiece, Score @@ -25,7 +25,7 @@ def __init__( self, *, aggregator: TrueFalseAggregatorFunc, - scorers: List[TrueFalseScorer], + scorers: list[TrueFalseScorer], ) -> None: """ Initialize the composite scorer. diff --git a/pyrit/score/true_false/true_false_score_aggregator.py b/pyrit/score/true_false/true_false_score_aggregator.py index 6a523eea0a..b0b1df7eea 100644 --- a/pyrit/score/true_false/true_false_score_aggregator.py +++ b/pyrit/score/true_false/true_false_score_aggregator.py @@ -3,7 +3,7 @@ import functools import operator -from typing import Callable, Iterable, List +from collections.abc import Callable, Iterable from pyrit.models import Score from pyrit.score.score_aggregator_result import ScoreAggregatorResult @@ -16,7 +16,7 @@ TrueFalseAggregatorFunc = Callable[[Iterable[Score]], ScoreAggregatorResult] -def _build_rationale(scores: List[Score], *, result: bool, true_msg: str, false_msg: str) -> tuple[str, str]: +def _build_rationale(scores: list[Score], *, result: bool, true_msg: str, false_msg: str) -> tuple[str, str]: """ Build description and rationale for aggregated true/false scores. @@ -42,7 +42,7 @@ def _build_rationale(scores: List[Score], *, result: bool, true_msg: str, false_ def _create_aggregator( name: str, *, - result_func: Callable[[List[bool]], bool], + result_func: Callable[[list[bool]], bool], true_msg: str, false_msg: str, ) -> TrueFalseAggregatorFunc: diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 999cf77bd6..09787ed67e 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -9,8 +9,9 @@ """ import pathlib +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.path import DEFAULT_CONFIG_PATH from pyrit.common.yaml_loadable import YamlLoadable @@ -29,10 +30,10 @@ # Type alias for YAML-serializable values that can be passed as initializer args # This matches what YAML can represent: primitives, lists, and nested dicts YamlPrimitive = Union[str, int, float, bool, None] -YamlValue = Union[YamlPrimitive, List["YamlValue"], Dict[str, "YamlValue"]] +YamlValue = Union[YamlPrimitive, list["YamlValue"], dict[str, "YamlValue"]] # Mapping from snake_case config values to internal constants -_MEMORY_DB_TYPE_MAP: Dict[str, str] = { +_MEMORY_DB_TYPE_MAP: dict[str, str] = { "in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL, @@ -50,7 +51,7 @@ class InitializerConfig: """ name: str - args: Optional[Dict[str, YamlValue]] = None + args: Optional[dict[str, YamlValue]] = None @dataclass @@ -90,9 +91,9 @@ class ConfigurationLoader(YamlLoadable): """ memory_db_type: str = "sqlite" - initializers: List[Union[str, Dict[str, Any]]] = field(default_factory=list) - initialization_scripts: Optional[List[str]] = None - env_files: Optional[List[str]] = None + initializers: list[Union[str, dict[str, Any]]] = field(default_factory=list) + initialization_scripts: Optional[list[str]] = None + env_files: Optional[list[str]] = None silent: bool = False def __post_init__(self) -> None: @@ -137,7 +138,7 @@ def _normalize_initializers(self) -> None: Raises: ValueError: If an initializer entry is missing a 'name' field or has an invalid type. """ - normalized: List[InitializerConfig] = [] + normalized: list[InitializerConfig] = [] for entry in self.initializers: if isinstance(entry, str): # Simple string entry: normalize name to snake_case @@ -159,7 +160,7 @@ def _normalize_initializers(self) -> None: self._initializer_configs = normalized @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": + def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": """ Create a ConfigurationLoader from a dictionary. @@ -178,7 +179,7 @@ def load_with_overrides( config_file: Optional[pathlib.Path] = None, *, memory_db_type: Optional[str] = None, - initializers: Optional[Sequence[Union[str, Dict[str, Any]]]] = None, + initializers: Optional[Sequence[Union[str, dict[str, Any]]]] = None, initialization_scripts: Optional[Sequence[str]] = None, env_files: Optional[Sequence[str]] = None, ) -> "ConfigurationLoader": @@ -213,7 +214,7 @@ def load_with_overrides( logger = logging.getLogger(__name__) # Start with defaults - None means "use defaults", [] means "load nothing" - config_data: Dict[str, Any] = { + config_data: dict[str, Any] = { "memory_db_type": "sqlite", "initializers": [], "initialization_scripts": None, # None = use defaults @@ -302,7 +303,7 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: return [] registry = InitializerRegistry() - resolved: List[PyRITInitializer] = [] + resolved: list[PyRITInitializer] = [] for config in self._initializer_configs: initializer_class = registry.get_class(config.name) @@ -338,7 +339,7 @@ def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: if len(self.initialization_scripts) == 0: return [] - resolved: List[pathlib.Path] = [] + resolved: list[pathlib.Path] = [] for script_str in self.initialization_scripts: script_path = pathlib.Path(script_str) if not script_path.is_absolute(): @@ -363,7 +364,7 @@ def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: if len(self.env_files) == 0: return [] - resolved: List[pathlib.Path] = [] + resolved: list[pathlib.Path] = [] for env_str in self.env_files: env_path = pathlib.Path(env_str) if not env_path.is_absolute(): diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 30231c7792..0aff8deafc 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -2,9 +2,10 @@ # Licensed under the MIT license. import logging import pathlib +from collections.abc import Sequence # Import PyRITInitializer for type checking (with TYPE_CHECKING to avoid circular imports) -from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args import dotenv diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index f257e52924..a0d81613df 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -9,7 +9,6 @@ """ import os -from typing import List from pyrit.common.apply_defaults import set_default_value, set_global_variable from pyrit.executor.attack import ( @@ -78,7 +77,7 @@ def description(self) -> str: ) @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """Get list of required environment variables.""" return [ "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index f421c53c6e..be42a8c173 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -15,7 +15,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, List, Optional, Type +from typing import Any, Optional from pyrit.prompt_target import ( AzureMLChatTarget, @@ -40,7 +40,7 @@ class TargetConfig: """Configuration for a target to be registered.""" registry_name: str - target_class: Type[PromptTarget] + target_class: type[PromptTarget] endpoint_var: str key_var: str = "" # Empty string means no auth required model_var: Optional[str] = None @@ -50,7 +50,7 @@ class TargetConfig: # Define all supported target configurations. # Only PRIMARY configurations are included here - alias configurations that use ${...} # syntax in .env_example are excluded since they reference other primary configurations. -TARGET_CONFIGS: List[TargetConfig] = [ +TARGET_CONFIGS: list[TargetConfig] = [ # ============================================ # OpenAI Chat Targets (OpenAIChatTarget) # ============================================ @@ -360,7 +360,7 @@ def description(self) -> str: ) @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """ Get list of required environment variables. diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index 45ec72a57b..109e5a8c01 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -10,8 +10,9 @@ import sys from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, Dict, Iterator, List +from typing import Any from pyrit.common.apply_defaults import get_global_default_values @@ -58,7 +59,7 @@ def description(self) -> str: return self.name @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """ Get list of required environment variables for this initializer. @@ -132,7 +133,7 @@ async def initialize_with_tracking_async(self) -> None: await self.initialize_async() @contextmanager - def _track_initialization_changes(self) -> Iterator[Dict[str, Any]]: + def _track_initialization_changes(self) -> Iterator[dict[str, Any]]: """ Context manager to track what changes during initialization. @@ -145,7 +146,7 @@ def _track_initialization_changes(self) -> Iterator[Dict[str, Any]]: current_main_dict = dict(sys.modules["__main__"].__dict__) # Initialize tracking dict - tracking_info: Dict[str, List[str]] = {"default_values": [], "global_variables": []} + tracking_info: dict[str, list[str]] = {"default_values": [], "global_variables": []} try: yield tracking_info @@ -166,7 +167,7 @@ def _track_initialization_changes(self) -> Iterator[Dict[str, Any]]: if name not in current_main_dict and name not in tracking_info["global_variables"]: tracking_info["global_variables"].append(name) - async def get_dynamic_default_values_info_async(self) -> Dict[str, Any]: + async def get_dynamic_default_values_info_async(self) -> dict[str, Any]: """ Get information about what default values and global variables this initializer sets. This is useful for debugging what default_values are set by an initializer. @@ -246,7 +247,7 @@ async def get_dynamic_default_values_info_async(self) -> Dict[str, Any]: sys.modules["__main__"].__dict__[var_name] = value @classmethod - async def get_info_async(cls) -> Dict[str, Any]: + async def get_info_async(cls) -> dict[str, Any]: """ Get information about this initializer class. diff --git a/pyrit/setup/initializers/scenarios/load_default_datasets.py b/pyrit/setup/initializers/scenarios/load_default_datasets.py index cca17dba6c..0055736372 100644 --- a/pyrit/setup/initializers/scenarios/load_default_datasets.py +++ b/pyrit/setup/initializers/scenarios/load_default_datasets.py @@ -10,7 +10,6 @@ import logging import textwrap -from typing import List from pyrit.datasets import SeedDatasetProvider from pyrit.memory import CentralMemory @@ -47,7 +46,7 @@ def description(self) -> str: ).strip() @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """Return the list of required environment variables.""" return [] @@ -57,7 +56,7 @@ async def initialize_async(self) -> None: registry = ScenarioRegistry.get_registry_singleton() # Collect all default datasets from all scenarios - all_default_datasets: List[str] = [] + all_default_datasets: list[str] = [] # Get all scenario names from registry scenario_names = registry.get_names() diff --git a/pyrit/setup/initializers/scenarios/objective_list.py b/pyrit/setup/initializers/scenarios/objective_list.py index da7572d343..a07ca9024c 100644 --- a/pyrit/setup/initializers/scenarios/objective_list.py +++ b/pyrit/setup/initializers/scenarios/objective_list.py @@ -10,8 +10,6 @@ should prefer using dataset_config in initialize_async for more flexibility. """ -from typing import List - from pyrit.common.apply_defaults import set_default_value from pyrit.scenario import Scenario from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer @@ -31,7 +29,7 @@ def execution_order(self) -> int: return 10 @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """Return an empty list because this initializer requires no environment variables.""" return [] diff --git a/pyrit/setup/initializers/scenarios/openai_objective_target.py b/pyrit/setup/initializers/scenarios/openai_objective_target.py index 6d4aa84082..6fc6ae4315 100644 --- a/pyrit/setup/initializers/scenarios/openai_objective_target.py +++ b/pyrit/setup/initializers/scenarios/openai_objective_target.py @@ -11,7 +11,6 @@ """ import os -from typing import List from pyrit.common.apply_defaults import set_default_value from pyrit.prompt_target import OpenAIChatTarget @@ -42,7 +41,7 @@ def description(self) -> str: ) @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """Get list of required environment variables.""" return [ "DEFAULT_OPENAI_FRONTEND_ENDPOINT", diff --git a/pyrit/setup/initializers/simple.py b/pyrit/setup/initializers/simple.py index f63dd50d8e..b8071bf6f9 100644 --- a/pyrit/setup/initializers/simple.py +++ b/pyrit/setup/initializers/simple.py @@ -8,8 +8,6 @@ simple configuration including converters, scorers, and targets using basic OpenAI. """ -from typing import List - from pyrit.common.apply_defaults import set_default_value, set_global_variable from pyrit.executor.attack import ( AttackAdversarialConfig, @@ -73,7 +71,7 @@ def description(self) -> str: ) @property - def required_env_vars(self) -> List[str]: + def required_env_vars(self) -> list[str]: """Get list of required environment variables.""" return [ "OPENAI_CHAT_ENDPOINT", diff --git a/pyrit/show_versions.py b/pyrit/show_versions.py index de68e06456..e19fde71ff 100644 --- a/pyrit/show_versions.py +++ b/pyrit/show_versions.py @@ -75,8 +75,8 @@ def show_versions() -> None: print("\nSystem:") for k, stat in sys_info.items(): - print("{k:>10}: {stat}".format(k=k, stat=stat)) + print(f"{k:>10}: {stat}") print("\nPython dependencies:") for k, stat in deps_info.items(): - print("{k:>13}: {stat}".format(k=k, stat=stat)) + print(f"{k:>13}: {stat}") diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index ec0bf7cdb3..7701e8d269 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -3,8 +3,9 @@ import logging import time +from collections.abc import Callable from threading import Semaphore, Thread -from typing import Any, Callable, Optional +from typing import Any, Optional from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models import MessagePiece, Score diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index e7ee001008..7a4312620f 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -3,8 +3,9 @@ import socket import time +from collections.abc import Callable from threading import Event, Semaphore, Thread -from typing import Any, Callable, Optional +from typing import Any, Optional import rpyc diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d271668e49..f3574a2b88 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,7 +4,7 @@ import asyncio import os import tempfile -from typing import Generator +from collections.abc import Generator from unittest.mock import patch import pytest diff --git a/tests/integration/memory/test_azure_sql_memory_integration.py b/tests/integration/memory/test_azure_sql_memory_integration.py index 2f500855e9..a50aad7ff5 100644 --- a/tests/integration/memory/test_azure_sql_memory_integration.py +++ b/tests/integration/memory/test_azure_sql_memory_integration.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from collections.abc import Generator from contextlib import closing, contextmanager from datetime import datetime, timedelta -from typing import Generator, List from uuid import uuid4 import numpy as np @@ -58,7 +58,7 @@ def get_test_scorer_identifier(**kwargs) -> ScorerIdentifier: @contextmanager -def cleanup_conversation_data(memory: AzureSQLMemory, conversation_ids: List[str]) -> Generator[None, None, None]: +def cleanup_conversation_data(memory: AzureSQLMemory, conversation_ids: list[str]) -> Generator[None, None, None]: """ Context manager to ensure cleanup of test data from attack results and message pieces. diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index c924fca504..a15db00b66 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Generator, Optional +from collections.abc import Generator +from typing import Optional from sqlalchemy import inspect diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index 2bd33946a7..9f7fdef20b 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import os -from typing import Generator +from collections.abc import Generator from unittest.mock import patch import pytest diff --git a/tests/integration/score/test_hitl_gradio_integration.py b/tests/integration/score/test_hitl_gradio_integration.py index da3898f2b9..123d467d7d 100644 --- a/tests/integration/score/test_hitl_gradio_integration.py +++ b/tests/integration/score/test_hitl_gradio_integration.py @@ -3,8 +3,9 @@ import importlib.util import time +from collections.abc import Callable from threading import Event, Thread -from typing import Callable, Optional +from typing import Optional from unittest.mock import MagicMock, patch import pytest diff --git a/tests/unit/analytics/test_conversation_analytics.py b/tests/unit/analytics/test_conversation_analytics.py index a2525c7be7..02f4534a2f 100644 --- a/tests/unit/analytics/test_conversation_analytics.py +++ b/tests/unit/analytics/test_conversation_analytics.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Sequence +from collections.abc import Sequence from unittest.mock import MagicMock import pytest diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2cf095920c..8ed3e87e5b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -3,7 +3,7 @@ import os import tempfile -from typing import Generator +from collections.abc import Generator from unittest.mock import patch import pytest diff --git a/tests/unit/converter/test_azure_speech_text_converter.py b/tests/unit/converter/test_azure_speech_text_converter.py index 86730c9561..21fae8ca87 100644 --- a/tests/unit/converter/test_azure_speech_text_converter.py +++ b/tests/unit/converter/test_azure_speech_text_converter.py @@ -56,10 +56,8 @@ def test_stop_cb(self, mock_logger, MockSpeechRecognizer, mock_get_required_valu # Check if the callback function worked as expected MockSpeechRecognizer.stop_continuous_recognition_async.assert_called_once() - mock_logger.info.assert_any_call("CLOSING on {}".format(mock_event)) - mock_logger.info.assert_any_call( - "Speech recognition canceled: {}".format(speechsdk.CancellationReason.EndOfStream) - ) + mock_logger.info.assert_any_call(f"CLOSING on {mock_event}") + mock_logger.info.assert_any_call(f"Speech recognition canceled: {speechsdk.CancellationReason.EndOfStream}") mock_logger.info.assert_called_with("End of audio stream detected.") @patch( @@ -86,7 +84,7 @@ def test_transcript_cb(self, mock_logger, MockSpeechRecognizer, mock_get_require converter.transcript_cb(evt=mock_event, transcript=transcript) # Check if the callback function worked as expected - mock_logger.info.assert_called_once_with("RECOGNIZED: {}".format(mock_event.result.text)) + mock_logger.info.assert_called_once_with(f"RECOGNIZED: {mock_event.result.text}") assert mock_event.result.text in transcript @patch( diff --git a/tests/unit/converter/test_first_letter_converter.py b/tests/unit/converter/test_first_letter_converter.py index f7d1ba8a36..1f7a09d350 100644 --- a/tests/unit/converter/test_first_letter_converter.py +++ b/tests/unit/converter/test_first_letter_converter.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index c888a960be..ff155ec9a2 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -61,7 +61,7 @@ def test_write_cache_json(self, tmp_path): loader._write_cache(cache_file=cache_file, examples=data, file_type="json") assert cache_file.exists() - with open(cache_file, "r", encoding="utf-8") as f: + with open(cache_file, encoding="utf-8") as f: loaded = json.load(f) assert loaded == data diff --git a/tests/unit/docs/test_api_documentation.py b/tests/unit/docs/test_api_documentation.py index 48db273589..e5e88c4860 100644 --- a/tests/unit/docs/test_api_documentation.py +++ b/tests/unit/docs/test_api_documentation.py @@ -11,7 +11,6 @@ import importlib import re from pathlib import Path -from typing import Set import pytest @@ -25,7 +24,7 @@ def get_api_rst_path() -> Path: return workspace_root / "doc" / "api.rst" -def get_documented_items_from_rst() -> dict[str, Set[str]]: +def get_documented_items_from_rst() -> dict[str, set[str]]: """ Parse api.rst and extract all documented items by module. @@ -35,7 +34,7 @@ def get_documented_items_from_rst() -> dict[str, Set[str]]: api_rst = get_api_rst_path() content = api_rst.read_text() - documented: dict[str, Set[str]] = {} + documented: dict[str, set[str]] = {} current_module = None # Find module definitions like :py:mod:`pyrit.prompt_converter` @@ -75,7 +74,7 @@ def get_documented_items_from_rst() -> dict[str, Set[str]]: return documented -def get_module_exports(module_path: str) -> Set[str]: +def get_module_exports(module_path: str) -> set[str]: """ Get all exported items from a module's __all__ list. diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index ce68cb9d65..c86e741e9c 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -18,7 +18,7 @@ """ import uuid -from typing import List, Optional +from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -133,7 +133,7 @@ def sample_system_piece() -> MessagePiece: @pytest.fixture -def sample_conversation(sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece) -> List[Message]: +def sample_conversation(sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece) -> list[Message]: """Create a sample conversation with user and assistant messages.""" return [ Message(message_pieces=[sample_user_piece]), @@ -509,7 +509,7 @@ def test_get_conversation_returns_empty_list_when_no_messages(self, attack_ident assert result == [] def test_get_conversation_returns_messages_in_order( - self, attack_identifier: ComponentIdentifier, sample_conversation: List[Message] + self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_conversation returns messages in order.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -537,7 +537,7 @@ def test_get_last_message_returns_none_for_empty_conversation(self, attack_ident assert result is None def test_get_last_message_returns_last_piece( - self, attack_identifier: ComponentIdentifier, sample_conversation: List[Message] + self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns the most recent message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -555,7 +555,7 @@ def test_get_last_message_returns_last_piece( assert result.api_role == "assistant" def test_get_last_message_with_role_filter( - self, attack_identifier: ComponentIdentifier, sample_conversation: List[Message] + self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message with role filter returns correct message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -574,7 +574,7 @@ def test_get_last_message_with_role_filter( assert result.api_role == "user" def test_get_last_message_with_role_filter_returns_none_when_no_match( - self, attack_identifier: ComponentIdentifier, sample_conversation: List[Message] + self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns None when no message matches role filter.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -718,7 +718,7 @@ async def test_adds_prepended_conversation_to_memory_for_chat_target( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that prepended conversation is added to memory for chat targets.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -766,7 +766,7 @@ async def test_normalizes_for_non_chat_target_by_default( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that prepended conversation is normalized for non-chat targets by default.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -792,7 +792,7 @@ async def test_normalizes_for_non_chat_target_when_configured( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that non-chat target normalizes prepended conversation when configured.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -821,7 +821,7 @@ async def test_returns_turn_count_for_multi_turn_attacks( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that turn count is returned for multi-turn attacks.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1022,7 +1022,7 @@ async def test_non_chat_target_behavior_normalize_is_default( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that non-chat targets normalize by default (no config), matching dataclass field default.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1048,7 +1048,7 @@ async def test_non_chat_target_behavior_raise_explicit( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that non_chat_target_behavior='raise' raises ValueError.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1073,7 +1073,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_creates_next_messag self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn creates next_message when none exists.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1101,7 +1101,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existin self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn prepends context to existing next_message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1131,7 +1131,7 @@ async def test_non_chat_target_behavior_normalize_returns_empty_state( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn returns empty ConversationState (no turn tracking).""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1161,7 +1161,7 @@ async def test_apply_converters_to_roles_default_applies_to_all( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that converters are applied to all roles by default.""" mock_normalizer = MagicMock(spec=PromptNormalizer) @@ -1188,7 +1188,7 @@ async def test_apply_converters_to_roles_user_only( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that converters are applied only to user role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) @@ -1217,7 +1217,7 @@ async def test_apply_converters_to_roles_assistant_only( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that converters are applied only to assistant role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) @@ -1246,7 +1246,7 @@ async def test_apply_converters_to_roles_empty_list_skips_all( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that empty roles list means no converters applied to any role.""" mock_normalizer = MagicMock(spec=PromptNormalizer) @@ -1279,7 +1279,7 @@ async def test_message_normalizer_default_uses_conversation_context_normalizer( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that default normalizer produces Turn N format.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1307,7 +1307,7 @@ async def test_message_normalizer_custom_normalizer_is_used( self, attack_identifier: ComponentIdentifier, mock_prompt_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that custom message_normalizer is used when provided.""" from pyrit.message_normalizer import MessageStringNormalizer @@ -1388,7 +1388,7 @@ async def test_chat_target_ignores_non_chat_target_behavior( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that chat targets ignore non_chat_target_behavior setting.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1469,7 +1469,7 @@ class TestAddPrependedConversationToMemory: async def test_adds_messages_to_memory( self, attack_identifier: ComponentIdentifier, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that messages are added to memory.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1488,7 +1488,7 @@ async def test_adds_messages_to_memory( async def test_assigns_conversation_id_to_all_pieces( self, attack_identifier: ComponentIdentifier, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that conversation_id is assigned to all message pieces.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1508,7 +1508,7 @@ async def test_assigns_conversation_id_to_all_pieces( async def test_assigns_attack_identifier_to_all_pieces( self, attack_identifier: ComponentIdentifier, - sample_conversation: List[Message], + sample_conversation: list[Message], ) -> None: """Test that attack_identifier is assigned to all message pieces.""" manager = ConversationManager(attack_identifier=attack_identifier) diff --git a/tests/unit/executor/attack/multi_turn/test_red_team_system.py b/tests/unit/executor/attack/multi_turn/test_red_team_system.py index 531a0b8e87..be7abe00dc 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_team_system.py +++ b/tests/unit/executor/attack/multi_turn/test_red_team_system.py @@ -7,7 +7,7 @@ def test_system_prompt_from_file(): strategy_path = RTASystemPromptPaths.TEXT_GENERATION.value - with open(strategy_path, "r") as strategy_file: + with open(strategy_path) as strategy_file: strategy = strategy_file.read() string_before_template = "value: |\n " strategy_template = strategy[strategy.find(string_before_template) + len(string_before_template) :] diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index cf1f1073ca..ef49184977 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -6,7 +6,7 @@ import logging import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -55,7 +55,7 @@ class NodeMockConfig: completed: bool = True off_topic: bool = False objective_score_value: Optional[float] = None - auxiliary_scores: Dict[str, float] = field(default_factory=dict) + auxiliary_scores: dict[str, float] = field(default_factory=dict) objective_target_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) adversarial_chat_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -135,7 +135,7 @@ def duplicate_side_effect(): return node @staticmethod - def create_nodes_with_scores(scores: List[float]) -> List[_TreeOfAttacksNode]: + def create_nodes_with_scores(scores: list[float]) -> list[_TreeOfAttacksNode]: """Create multiple nodes with the given objective scores.""" return [ MockNodeFactory.create_node(NodeMockConfig(node_id=f"node_{i}", objective_score_value=score)) @@ -150,8 +150,8 @@ def __init__(self) -> None: self.objective_target: Optional[PromptTarget] = None self.adversarial_chat: Optional[PromptChatTarget] = None self.objective_scorer: Optional[Scorer] = None - self.auxiliary_scorers: List[Scorer] = [] - self.tree_params: Dict[str, Any] = {} + self.auxiliary_scorers: list[Scorer] = [] + self.tree_params: dict[str, Any] = {} self.converters: Optional[AttackConverterConfig] = None self.successful_threshold: float = 0.8 self.prompt_normalizer: Optional[PromptNormalizer] = None @@ -378,7 +378,7 @@ def create_threshold_score(*, original_float_value: float, threshold: float = 0. ) @staticmethod - def add_nodes_to_tree(context: TAPAttackContext, nodes: List[_TreeOfAttacksNode], parent: str = "root"): + def add_nodes_to_tree(context: TAPAttackContext, nodes: list[_TreeOfAttacksNode], parent: str = "root"): """Add nodes to the context's tree visualization.""" for i, node in enumerate(nodes): score_str = "" diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index a72772c70e..673a87cf26 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -9,7 +9,7 @@ """ import uuid -from typing import List, Optional +from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -107,7 +107,7 @@ def multimodal_audio_message() -> Message: @pytest.fixture -def prepended_conversation_text() -> List[Message]: +def prepended_conversation_text() -> list[Message]: """Create a text-only prepended conversation.""" return [ Message.from_prompt(prompt="Hello, I need help with something.", role="user"), @@ -118,7 +118,7 @@ def prepended_conversation_text() -> List[Message]: @pytest.fixture -def prepended_conversation_multimodal() -> List[Message]: +def prepended_conversation_multimodal() -> list[Message]: """Create a multimodal prepended conversation with image content.""" conv_id = str(uuid.uuid4()) return [ @@ -572,7 +572,7 @@ class TestPrependedConversationInMemory: def _assert_assistant_translated_to_simulated( self, *, - conversation: List[Message], + conversation: list[Message], prepended_count: int, ) -> None: """ @@ -606,7 +606,7 @@ async def test_prompt_sending_attack_adds_prepended_to_memory( self, mock_chat_target: MagicMock, sample_response: Message, - prepended_conversation_multimodal: List[Message], + prepended_conversation_multimodal: list[Message], sqlite_instance, ) -> None: """Test that prepended conversation is preserved in memory with correct role translation.""" @@ -649,7 +649,7 @@ async def test_prompt_sending_attack_adds_prepended_to_memory( async def test_red_teaming_attack_adds_prepended_to_memory( self, red_teaming_attack: RedTeamingAttack, - prepended_conversation_multimodal: List[Message], + prepended_conversation_multimodal: list[Message], sqlite_instance, ) -> None: """Test that RedTeamingAttack preserves prepended conversation in memory with role translation.""" @@ -683,7 +683,7 @@ async def test_red_teaming_attack_adds_prepended_to_memory( async def test_crescendo_attack_adds_prepended_to_memory( self, crescendo_attack: CrescendoAttack, - prepended_conversation_multimodal: List[Message], + prepended_conversation_multimodal: list[Message], multimodal_text_message: Message, sqlite_instance, ) -> None: @@ -723,7 +723,7 @@ async def test_tap_attack_adds_prepended_to_memory( mock_objective_scorer: MagicMock, sample_response: Message, success_score: Score, - prepended_conversation_multimodal: List[Message], + prepended_conversation_multimodal: list[Message], multimodal_text_message: Message, sqlite_instance, ) -> None: @@ -801,7 +801,7 @@ class TestMultiTurnTurnCounting: async def test_red_teaming_starts_with_prepended_turn_count( self, red_teaming_attack: RedTeamingAttack, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], ) -> None: """Test that RedTeamingAttack starts executed_turns at prepended turn count.""" # The prepended_conversation_text has 2 assistant messages @@ -818,7 +818,7 @@ async def test_red_teaming_starts_with_prepended_turn_count( async def test_crescendo_starts_with_prepended_turn_count( self, crescendo_attack: CrescendoAttack, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], multimodal_text_message: Message, ) -> None: """Test that CrescendoAttack starts executed_turns at prepended turn count.""" @@ -836,7 +836,7 @@ async def test_crescendo_starts_with_prepended_turn_count( async def test_tap_starts_with_prepended_turn_count( self, tap_attack: TreeOfAttacksWithPruningAttack, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], multimodal_text_message: Message, ) -> None: """Test that TreeOfAttacksWithPruningAttack starts executed_turns at prepended turn count.""" @@ -894,7 +894,7 @@ async def test_prompt_sending_attack_propagates_memory_labels( # ============================================================================= -def _get_adversarial_chat_text_values(*, adversarial_chat_conversation_id: str) -> List[str]: +def _get_adversarial_chat_text_values(*, adversarial_chat_conversation_id: str) -> list[str]: """ Get all text values from the adversarial chat conversation in memory. @@ -920,7 +920,7 @@ def _get_adversarial_chat_text_values(*, adversarial_chat_conversation_id: str) def _assert_prepended_text_in_adversarial_context( *, - prepended_conversation: List[Message], + prepended_conversation: list[Message], adversarial_chat_conversation_id: str, adversarial_chat_mock: Optional[MagicMock] = None, ) -> None: @@ -981,7 +981,7 @@ async def test_red_teaming_injects_prepended_into_adversarial_context( self, red_teaming_attack: RedTeamingAttack, mock_adversarial_chat: MagicMock, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], sqlite_instance, ) -> None: """Test that RedTeamingAttack injects prepended conversation into adversarial chat context.""" @@ -1007,7 +1007,7 @@ async def test_crescendo_injects_prepended_into_adversarial_context( self, crescendo_attack: CrescendoAttack, mock_adversarial_chat: MagicMock, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], multimodal_text_message: Message, sqlite_instance, ) -> None: @@ -1035,7 +1035,7 @@ async def test_tap_injects_prepended_into_adversarial_context( self, tap_attack: TreeOfAttacksWithPruningAttack, mock_adversarial_chat: MagicMock, - prepended_conversation_text: List[Message], + prepended_conversation_text: list[Message], multimodal_text_message: Message, sqlite_instance, ) -> None: diff --git a/tests/unit/executor/benchmark/test_fairness_bias.py b/tests/unit/executor/benchmark/test_fairness_bias.py index a9896a7827..b33a20da78 100644 --- a/tests/unit/executor/benchmark/test_fairness_bias.py +++ b/tests/unit/executor/benchmark/test_fairness_bias.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Dict, List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -86,7 +85,7 @@ def sample_attack_result() -> AttackResult: @pytest.fixture -def mock_conversation_pieces() -> List[Message]: +def mock_conversation_pieces() -> list[Message]: """Mock conversation pieces for memory testing.""" return [ Message( @@ -259,7 +258,7 @@ async def test_perform_async_calls_prompt_sending_attack( mock_prompt_target: MagicMock, sample_benchmark_context: FairnessBiasBenchmarkContext, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test that perform_async calls the underlying PromptSendingAttack.""" with patch("pyrit.executor.benchmark.fairness_bias.PromptSendingAttack") as mock_attack_class: @@ -395,7 +394,7 @@ async def test_execute_async_with_required_parameters( self, mock_prompt_target: MagicMock, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test execute_async with only required parameters.""" with patch("pyrit.executor.benchmark.fairness_bias.PromptSendingAttack") as mock_attack_class: @@ -421,11 +420,11 @@ async def test_execute_async_with_optional_parameters( self, mock_prompt_target: MagicMock, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test execute_async with optional parameters.""" - prepended_conversation: List[Message] = [] - memory_labels: Dict[str, str] = {"test": "label"} + prepended_conversation: list[Message] = [] + memory_labels: dict[str, str] = {"test": "label"} custom_objective = "Custom story objective" with patch("pyrit.executor.benchmark.fairness_bias.PromptSendingAttack") as mock_attack_class: @@ -464,7 +463,7 @@ async def test_execute_async_multiple_experiments( self, mock_prompt_target: MagicMock, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test execute_async with multiple experiments.""" with patch("pyrit.executor.benchmark.fairness_bias.PromptSendingAttack") as mock_attack_class: @@ -501,7 +500,7 @@ async def test_full_benchmark_workflow( self, mock_prompt_target: MagicMock, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test full benchmark workflow from start to finish.""" with patch("pyrit.executor.benchmark.fairness_bias.PromptSendingAttack") as mock_attack_class: @@ -542,7 +541,7 @@ async def test_benchmark_with_memory_labels( self, mock_prompt_target: MagicMock, sample_attack_result: AttackResult, - mock_conversation_pieces: List[Message], + mock_conversation_pieces: list[Message], ) -> None: """Test benchmark execution with memory labels.""" memory_labels = {"experiment_type": "fairness_test", "model": "test_model"} diff --git a/tests/unit/executor/benchmark/test_question_answering.py b/tests/unit/executor/benchmark/test_question_answering.py index 6a9c653d1f..c14a89c668 100644 --- a/tests/unit/executor/benchmark/test_question_answering.py +++ b/tests/unit/executor/benchmark/test_question_answering.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Dict, List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -358,8 +357,8 @@ async def test_execute_async_with_optional_parameters( sample_attack_result: AttackResult, ) -> None: """Test execute_async with optional parameters.""" - prepended_conversation: List[Message] = [] - memory_labels: Dict[str, str] = {"test": "label"} + prepended_conversation: list[Message] = [] + memory_labels: dict[str, str] = {"test": "label"} with patch("pyrit.executor.benchmark.question_answering.PromptSendingAttack") as mock_attack_class: mock_attack_instance = AsyncMock() @@ -409,7 +408,7 @@ async def test_context_with_prepended_conversation( mock_response = MagicMock(spec=Message) mock_response.message_pieces = [mock_message_piece] - prepended_conversation: List[Message] = [mock_response] + prepended_conversation: list[Message] = [mock_response] context = QuestionAnsweringBenchmarkContext( question_answering_entry=sample_question_entry, prepended_conversation=prepended_conversation diff --git a/tests/unit/executor/promptgen/test_anecdoctor.py b/tests/unit/executor/promptgen/test_anecdoctor.py index 809016380b..31d4667cad 100644 --- a/tests/unit/executor/promptgen/test_anecdoctor.py +++ b/tests/unit/executor/promptgen/test_anecdoctor.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -54,7 +53,7 @@ def mock_prompt_normalizer() -> PromptNormalizer: @pytest.fixture -def sample_evaluation_data() -> List[str]: +def sample_evaluation_data() -> list[str]: """Sample evaluation data for testing.""" return [ "Claim: The earth is flat. Review: FALSE", diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 71c7670899..8fadfb013d 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -3,7 +3,8 @@ import uuid -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from pyrit.common.utils import to_sha256 from pyrit.identifiers import ComponentIdentifier diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 8fe3b2fa50..14064b718d 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -3,8 +3,8 @@ import os import tempfile +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from unittest.mock import MagicMock, patch from pyrit.common.path import DB_DATA_PATH @@ -105,7 +105,7 @@ def test_export_all_conversations_with_scores_correct_data(sqlite_instance: Memo # Read and verify the exported JSON content import json - with open(file_path, "r") as f: + with open(file_path) as f: exported_data = json.load(f) assert len(exported_data) == 1 @@ -143,7 +143,7 @@ def test_export_all_conversations_with_scores_empty_data(sqlite_instance: Memory # Read and verify the exported JSON content is empty import json - with open(file_path, "r") as f: + with open(file_path) as f: exported_data = json.load(f) assert exported_data == [] diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 65ac29420a..54835a18af 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -3,8 +3,8 @@ import uuid +from collections.abc import MutableSequence, Sequence from datetime import datetime -from typing import MutableSequence, Sequence from unittest.mock import MagicMock, patch from uuid import uuid4 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 66c98c7528..6087af1418 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -3,7 +3,8 @@ import uuid -from typing import Literal, Sequence +from collections.abc import Sequence +from typing import Literal from uuid import uuid4 import pytest diff --git a/tests/unit/memory/memory_interface/test_interface_seed_prompts.py b/tests/unit/memory/memory_interface/test_interface_seed_prompts.py index 84e12ce3cb..662dfb8656 100644 --- a/tests/unit/memory/memory_interface/test_interface_seed_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_seed_prompts.py @@ -3,7 +3,7 @@ import os import tempfile -from typing import Sequence +from collections.abc import Sequence from unittest.mock import patch from uuid import uuid4 diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index a9e4db4f9b..6001c2756a 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -3,7 +3,7 @@ import os import uuid -from typing import Generator, MutableSequence, Sequence +from collections.abc import Generator, MutableSequence, Sequence import pytest diff --git a/tests/unit/memory/test_memory_embedding.py b/tests/unit/memory/test_memory_embedding.py index f1a6a4e928..7fe1c1d4b6 100644 --- a/tests/unit/memory/test_memory_embedding.py +++ b/tests/unit/memory/test_memory_embedding.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import MutableSequence, Sequence +from collections.abc import MutableSequence, Sequence import pytest diff --git a/tests/unit/memory/test_memory_exporter.py b/tests/unit/memory/test_memory_exporter.py index d9d247292e..91cc79de5c 100644 --- a/tests/unit/memory/test_memory_exporter.py +++ b/tests/unit/memory/test_memory_exporter.py @@ -3,7 +3,7 @@ import csv import json -from typing import Sequence +from collections.abc import Sequence import pytest from sqlalchemy.inspection import inspect @@ -26,10 +26,10 @@ def model_to_dict(instance): def read_file(file_path, export_type): if export_type == "json": - with open(file_path, "r") as f: + with open(file_path) as f: return json.load(f) elif export_type == "csv": - with open(file_path, "r", newline="") as f: + with open(file_path, newline="") as f: reader = csv.DictReader(f) return list(reader) else: diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index eb345bb2f6..f99b725258 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -3,7 +3,7 @@ import os import uuid -from typing import Sequence +from collections.abc import Sequence from unittest.mock import MagicMock import pytest diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index d61645503f..7da46c5fee 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -5,8 +5,9 @@ import shutil import tempfile import uuid +from collections.abc import Generator, MutableSequence, Sequence from contextlib import AbstractAsyncContextManager -from typing import Generator, MutableSequence, Optional, Sequence +from typing import Optional from unittest.mock import MagicMock, patch from pyrit.identifiers import ComponentIdentifier diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index c0672d07a0..469c94e87b 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -5,8 +5,8 @@ import tempfile import time import uuid +from collections.abc import MutableSequence from datetime import datetime, timedelta -from typing import MutableSequence import pytest from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations diff --git a/tests/unit/scenarios/test_cyber.py b/tests/unit/scenarios/test_cyber.py index 736ed03158..749d05d27b 100644 --- a/tests/unit/scenarios/test_cyber.py +++ b/tests/unit/scenarios/test_cyber.py @@ -4,7 +4,6 @@ """Tests for the Cyber class.""" import pathlib -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -112,7 +111,7 @@ def mock_adversarial_target(): @pytest.fixture -def sample_objectives() -> List[str]: +def sample_objectives() -> list[str]: """Create sample objectives for testing.""" return ["test prompt 1", "test prompt 2"] diff --git a/tests/unit/scenarios/test_jailbreak.py b/tests/unit/scenarios/test_jailbreak.py index 675fe3b73e..3e91d3c94e 100644 --- a/tests/unit/scenarios/test_jailbreak.py +++ b/tests/unit/scenarios/test_jailbreak.py @@ -3,7 +3,6 @@ """Tests for the Jailbreak class.""" -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -21,7 +20,7 @@ @pytest.fixture -def mock_templates() -> List[str]: +def mock_templates() -> list[str]: """Mock constant for jailbreak subset.""" return ["aim", "dan_1", "tuo"] @@ -44,7 +43,7 @@ def mock_scenario_result_id() -> str: @pytest.fixture -def mock_memory_seed_groups() -> List[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" return [ SeedGroup(seeds=[SeedObjective(value=prompt)]) @@ -384,7 +383,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], ) -> None: """Test initialization with custom max_concurrency.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -398,7 +397,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], ) -> None: """Test initialization with memory labels.""" memory_labels = {"type": "jailbreak", "category": "scenario"} @@ -435,7 +434,7 @@ def test_scenario_default_dataset(self) -> None: @pytest.mark.asyncio async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: List[SeedGroup] + self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup] ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -478,7 +477,7 @@ async def test_roleplay_attacks_share_adversarial_target( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], roleplay_jailbreak_strategy: JailbreakStrategy, ) -> None: """Test that multiple role-play attacks share the same adversarial target instance.""" diff --git a/tests/unit/scenarios/test_leakage_scenario.py b/tests/unit/scenarios/test_leakage_scenario.py index 357ea296dd..86773e002a 100644 --- a/tests/unit/scenarios/test_leakage_scenario.py +++ b/tests/unit/scenarios/test_leakage_scenario.py @@ -4,7 +4,6 @@ """Tests for the LeakageScenario class.""" import pathlib -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -119,7 +118,7 @@ def mock_adversarial_target(): @pytest.fixture -def sample_objectives() -> List[str]: +def sample_objectives() -> list[str]: return ["test leakage prompt 1", "test leakage prompt 2"] diff --git a/tests/unit/scenarios/test_psychosocial_harms.py b/tests/unit/scenarios/test_psychosocial_harms.py index 462dcd2a6c..6c683e8ba9 100644 --- a/tests/unit/scenarios/test_psychosocial_harms.py +++ b/tests/unit/scenarios/test_psychosocial_harms.py @@ -3,7 +3,7 @@ """Tests for the PsychosocialScenario class.""" -from typing import Dict, List, Sequence +from collections.abc import Sequence from unittest.mock import MagicMock, patch import pytest @@ -29,7 +29,7 @@ @pytest.fixture -def mock_memory_seed_groups() -> List[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" return [SeedGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] @@ -50,7 +50,7 @@ def imminent_crisis_strategy() -> PsychosocialStrategy: @pytest.fixture -def psychosocial_prompts() -> List[str]: +def psychosocial_prompts() -> list[str]: return SEED_PROMPT_LIST @@ -92,7 +92,7 @@ def mock_adversarial_target() -> PromptChatTarget: @pytest.fixture -def sample_objectives_by_harm() -> Dict[str, Sequence[SeedGroup]]: +def sample_objectives_by_harm() -> dict[str, Sequence[SeedGroup]]: return { "psychosocial_imminent_crisis": [ SeedGroup( @@ -116,7 +116,7 @@ def sample_objectives_by_harm() -> Dict[str, Sequence[SeedGroup]]: @pytest.fixture -def sample_objectives() -> List[str]: +def sample_objectives() -> list[str]: return ["psychosocial prompt 1", "psychosocial prompt 2"] @@ -131,7 +131,7 @@ def test_init_with_custom_objectives( self, *, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test initialization with custom objectives (deprecated parameter).""" scenario = PsychosocialScenario( @@ -230,7 +230,7 @@ async def test_attack_generation_for_all( self, mock_objective_target, mock_objective_scorer, - sample_objectives: List[str], + sample_objectives: list[str], ): """Test that _get_atomic_attacks_async returns atomic attacks.""" scenario = PsychosocialScenario(objectives=sample_objectives, objective_scorer=mock_objective_scorer) @@ -248,7 +248,7 @@ async def test_attack_generation_for_singleturn_async( mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, single_turn_strategy: PsychosocialStrategy, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test that the single turn strategy attack generation works.""" scenario = PsychosocialScenario( @@ -270,7 +270,7 @@ async def test_attack_generation_for_multiturn_async( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], multi_turn_strategy: PsychosocialStrategy, ) -> None: """Test that the multi turn attack generation works.""" @@ -293,7 +293,7 @@ async def test_attack_generation_for_imminent_crisis_async( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], imminent_crisis_strategy: PsychosocialStrategy, ) -> None: """Test that the imminent crisis strategy generates both single and multi-turn attacks.""" @@ -318,7 +318,7 @@ async def test_attack_runs_include_objectives_async( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test that attack runs include objectives for each seed prompt.""" scenario = PsychosocialScenario( @@ -341,7 +341,7 @@ async def test_get_atomic_attacks_async_returns_attacks( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test that _get_atomic_attacks_async returns atomic attacks.""" scenario = PsychosocialScenario( @@ -365,7 +365,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test initialization with custom max_concurrency.""" scenario = PsychosocialScenario(objectives=sample_objectives, objective_scorer=mock_objective_scorer) @@ -378,7 +378,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptChatTarget, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test initialization with memory labels.""" memory_labels = {"type": "psychosocial", "category": "crisis"} @@ -399,7 +399,7 @@ def test_scenario_version_is_set( self, *, mock_objective_scorer: FloatScaleThresholdScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test that scenario version is properly set.""" scenario = PsychosocialScenario( @@ -422,7 +422,7 @@ async def test_no_target_duplication_async( self, *, mock_objective_target: PromptChatTarget, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: """Test that all three targets (adversarial, objective, scorer) are distinct.""" scenario = PsychosocialScenario(objectives=sample_objectives) diff --git a/tests/unit/scenarios/test_scam.py b/tests/unit/scenarios/test_scam.py index bf846792f6..9f5d1c3f0a 100644 --- a/tests/unit/scenarios/test_scam.py +++ b/tests/unit/scenarios/test_scam.py @@ -4,7 +4,6 @@ """Tests for the Scam class.""" import pathlib -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -44,7 +43,7 @@ def _mock_target_id(name: str = "MockTarget") -> ComponentIdentifier: @pytest.fixture -def mock_memory_seed_groups() -> List[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" return [SeedGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] @@ -77,7 +76,7 @@ def multi_turn_strategy() -> ScamStrategy: @pytest.fixture -def scam_prompts() -> List[str]: +def scam_prompts() -> list[str]: return SEED_PROMPT_LIST @@ -119,7 +118,7 @@ def mock_adversarial_target() -> PromptChatTarget: @pytest.fixture -def sample_objectives() -> List[str]: +def sample_objectives() -> list[str]: return ["scam prompt 1", "scam prompt 2"] @@ -134,7 +133,7 @@ def test_init_with_custom_objectives( self, *, mock_objective_scorer: TrueFalseCompositeScorer, - sample_objectives: List[str], + sample_objectives: list[str], ) -> None: scenario = Scam( objectives=sample_objectives, @@ -150,7 +149,7 @@ def test_init_with_default_objectives( self, *, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -166,7 +165,7 @@ def test_init_with_default_scorer(self, mock_memory_seed_groups) -> None: scenario = Scam() assert scenario._objective_scorer_identifier - def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: List[SeedGroup]) -> None: + def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedGroup]) -> None: """Test initialization with custom scorer.""" scorer = MagicMock(spec=TrueFalseCompositeScorer) @@ -175,7 +174,7 @@ def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: List[SeedGrou assert isinstance(scenario._scorer_config, AttackScoringConfig) def test_init_default_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: List[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -184,7 +183,7 @@ def test_init_default_adversarial_chat( assert scenario._adversarial_chat._temperature == 1.2 def test_init_with_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: List[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] ) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") @@ -322,7 +321,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], mock_dataset_config, ) -> None: """Test initialization with custom max_concurrency.""" @@ -339,7 +338,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: List[SeedGroup], + mock_memory_seed_groups: list[SeedGroup], mock_dataset_config, ) -> None: """Test initialization with memory labels.""" @@ -373,7 +372,7 @@ def test_scenario_version_is_set( @pytest.mark.asyncio async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: List[SeedGroup], mock_dataset_config + self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup], mock_dataset_config ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): diff --git a/tests/unit/score/test_batch_scorer.py b/tests/unit/score/test_batch_scorer.py index ffe180223c..32ed893f70 100644 --- a/tests/unit/score/test_batch_scorer.py +++ b/tests/unit/score/test_batch_scorer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import uuid -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/score/test_prompt_shield_scorer.py b/tests/unit/score/test_prompt_shield_scorer.py index 6311267299..b2f284dfe8 100644 --- a/tests/unit/score/test_prompt_shield_scorer.py +++ b/tests/unit/score/test_prompt_shield_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import Mock import pytest diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 934e74e12a..c6b3c11150 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -282,7 +282,7 @@ async def test_scorer_send_chat_target_async_good_response(good_json): objective="task", ) - assert chat_target.send_prompt_async.call_count == int(1) + assert chat_target.send_prompt_async.call_count == 1 @pytest.mark.asyncio diff --git a/tests/unit/score/test_scorer_eval_csv_schema.py b/tests/unit/score/test_scorer_eval_csv_schema.py index 75df9f67c3..5e0a025430 100644 --- a/tests/unit/score/test_scorer_eval_csv_schema.py +++ b/tests/unit/score/test_scorer_eval_csv_schema.py @@ -10,7 +10,6 @@ import csv from pathlib import Path -from typing import List import pytest @@ -32,11 +31,11 @@ class TestObjectiveScorerEvalCSVSchema: """Test that all objective scorer evaluation CSVs have the correct schema.""" @pytest.fixture(scope="class") - def objective_csv_files(self) -> List[Path]: + def objective_csv_files(self) -> list[Path]: """Get all CSV files in the objective scorer evals directory.""" return list(Path(SCORER_EVALS_OBJECTIVE_PATH).glob("*.csv")) - def test_objective_csv_files_exist(self, objective_csv_files: List[Path]) -> None: + def test_objective_csv_files_exist(self, objective_csv_files: list[Path]) -> None: """Verify that objective CSV files exist.""" assert len(objective_csv_files) > 0, "No objective CSV files found" @@ -55,7 +54,7 @@ def test_objective_csv_has_required_columns(self, csv_file: Path) -> None: - human_score: The human-labeled ground truth score - data_type: The type of data (e.g., "text") """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -88,7 +87,7 @@ def test_objective_csv_column_names_exact(self, csv_file: Path) -> None: This ensures no typos or legacy column names remain. """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -116,11 +115,11 @@ class TestHarmScorerEvalCSVSchema: """Test that all harm scorer evaluation CSVs have the correct schema.""" @pytest.fixture(scope="class") - def harm_csv_files(self) -> List[Path]: + def harm_csv_files(self) -> list[Path]: """Get all CSV files in the harm scorer evals directory.""" return list(Path(SCORER_EVALS_HARM_PATH).glob("*.csv")) - def test_harm_csv_files_exist(self, harm_csv_files: List[Path]) -> None: + def test_harm_csv_files_exist(self, harm_csv_files: list[Path]) -> None: """Verify that harm CSV files exist.""" assert len(harm_csv_files) > 0, "No harm CSV files found" @@ -143,7 +142,7 @@ def test_harm_csv_has_required_columns(self, csv_file: Path) -> None: Note: Harm CSVs may have additional human_score_2, human_score_3, etc. for multi-annotator datasets. """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -179,7 +178,7 @@ def test_harm_csv_has_human_score_columns(self, csv_file: Path) -> None: - human_score_2 (optional) - human_score_3 (optional) """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -203,11 +202,11 @@ class TestRefusalScorerEvalCSVSchema: """Test that all refusal scorer evaluation CSVs have the correct schema.""" @pytest.fixture(scope="class") - def refusal_csv_files(self) -> List[Path]: + def refusal_csv_files(self) -> list[Path]: """Get all CSV files in the refusal scorer evals directory.""" return list(Path(SCORER_EVALS_REFUSAL_SCORER_PATH).glob("*.csv")) - def test_refusal_csv_files_exist(self, refusal_csv_files: List[Path]) -> None: + def test_refusal_csv_files_exist(self, refusal_csv_files: list[Path]) -> None: """Verify that refusal CSV files exist.""" assert len(refusal_csv_files) > 0, "No refusal CSV files found" @@ -226,7 +225,7 @@ def test_refusal_csv_has_required_columns(self, csv_file: Path) -> None: - human_score: The human-labeled ground truth score - data_type: The type of data (e.g., "text") """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -258,7 +257,7 @@ def test_refusal_csv_column_names_exact(self, csv_file: Path) -> None: This ensures no typos or legacy column names remain. """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: # Skip version line if present first_line = f.readline() if not first_line.startswith("# dataset_version="): @@ -284,7 +283,7 @@ class TestCSVVersionMetadata: """Test that all CSV files have version metadata.""" @pytest.fixture(scope="class") - def all_csv_files(self) -> List[Path]: + def all_csv_files(self) -> list[Path]: """Get all CSV files from all scorer eval directories.""" files: list[Path] = [] files.extend(Path(SCORER_EVALS_OBJECTIVE_PATH).glob("*.csv")) @@ -305,7 +304,7 @@ def test_csv_has_dataset_version_line(self, csv_file: Path) -> None: Version line format: # dataset_version=X.Y """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: first_line = f.readline().strip() assert first_line.startswith("#") and "dataset_version=" in first_line, ( @@ -324,7 +323,7 @@ def test_harm_csv_has_harm_definition(self, csv_file: Path) -> None: Format: # dataset_version=X.Y, harm_definition=path/to/definition.yaml """ - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: first_line = f.readline().strip() assert "harm_definition=" in first_line, ( @@ -348,7 +347,7 @@ def test_harm_definition_file_exists_and_is_valid(self, csv_file: Path) -> None: """ from pyrit.models.harm_definition import HarmDefinition - with open(csv_file, "r", encoding="utf-8") as f: + with open(csv_file, encoding="utf-8") as f: first_line = f.readline().strip() # Parse harm_definition from the comment line diff --git a/tests/unit/score/test_self_ask_scale.py b/tests/unit/score/test_self_ask_scale.py index fada168b0f..824e594e84 100644 --- a/tests/unit/score/test_self_ask_scale.py +++ b/tests/unit/score/test_self_ask_scale.py @@ -229,4 +229,4 @@ async def test_scale_scorer_score_calls_send_chat(patch_central_database): scorer._score_value_with_llm = AsyncMock(return_value=score) await scorer.score_text_async(text="example text", objective="task") - assert scorer._score_value_with_llm.call_count == int(1) + assert scorer._score_value_with_llm.call_count == 1 diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index a395b0ea08..19a168c980 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -5,7 +5,6 @@ Unit tests for LoadDefaultDatasets initializer. """ -from typing import List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -195,8 +194,8 @@ async def test_all_required_datasets_available_in_seed_provider(self) -> None: scenario_names = registry.get_names() # Collect all required datasets from all scenarios - missing_datasets: List[str] = [] - scenario_dataset_map: dict[str, List[str]] = {} + missing_datasets: list[str] = [] + scenario_dataset_map: dict[str, list[str]] = {} for scenario_name in scenario_names: scenario_class = registry.get_class(scenario_name) diff --git a/tests/unit/target/test_azure_ml_chat_target.py b/tests/unit/target/test_azure_ml_chat_target.py index abd29f9fb1..6c49d4a014 100644 --- a/tests/unit/target/test_azure_ml_chat_target.py +++ b/tests/unit/target/test_azure_ml_chat_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import os -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_azure_openai_completion_target.py b/tests/unit/target/test_azure_openai_completion_target.py index 26d5d2a0cd..a4bb4160eb 100644 --- a/tests/unit/target/test_azure_openai_completion_target.py +++ b/tests/unit/target/test_azure_openai_completion_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import os -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index e75dfdc278..39d3e6cd39 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Callable +from collections.abc import Callable from unittest.mock import AsyncMock, MagicMock, patch import httpx diff --git a/tests/unit/target/test_http_target_parsing.py b/tests/unit/target/test_http_target_parsing.py index 5840c4e3ca..a99ea625e8 100644 --- a/tests/unit/target/test_http_target_parsing.py +++ b/tests/unit/target/test_http_target_parsing.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. -from typing import Callable +from collections.abc import Callable from unittest.mock import MagicMock import pytest diff --git a/tests/unit/target/test_image_target.py b/tests/unit/target/test_image_target.py index 30d5726000..d25655f6c9 100644 --- a/tests/unit/target/test_image_target.py +++ b/tests/unit/target/test_image_target.py @@ -3,7 +3,7 @@ import os import uuid -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index bc17cdc052..0fe473a2db 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -5,8 +5,8 @@ import json import logging import os +from collections.abc import MutableSequence from tempfile import NamedTemporaryFile -from typing import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import httpx diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 1e4b95e0b5..33114cddef 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -3,8 +3,9 @@ import json import os +from collections.abc import MutableSequence from tempfile import NamedTemporaryFile -from typing import Any, MutableSequence +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_playwright_target.py b/tests/unit/target/test_playwright_target.py index ce91caf789..b385e4e7e9 100644 --- a/tests/unit/target/test_playwright_target.py +++ b/tests/unit/target/test_playwright_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock import pytest diff --git a/tests/unit/target/test_prompt_shield_target.py b/tests/unit/target/test_prompt_shield_target.py index 3ac3525f79..85866537b1 100644 --- a/tests/unit/target/test_prompt_shield_target.py +++ b/tests/unit/target/test_prompt_shield_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import MagicMock import pytest diff --git a/tests/unit/target/test_prompt_target.py b/tests/unit/target/test_prompt_target.py index 9e9c717a50..7ed2f37be7 100644 --- a/tests/unit/target/test_prompt_target.py +++ b/tests/unit/target/test_prompt_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_prompt_target_azure_blob_storage.py b/tests/unit/target/test_prompt_target_azure_blob_storage.py index 85438c5853..9fbe5f28b9 100644 --- a/tests/unit/target/test_prompt_target_azure_blob_storage.py +++ b/tests/unit/target/test_prompt_target_azure_blob_storage.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import os -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/target/test_prompt_target_text.py b/tests/unit/target/test_prompt_target_text.py index 5578fa7f00..ceee9eaa00 100644 --- a/tests/unit/target/test_prompt_target_text.py +++ b/tests/unit/target/test_prompt_target_text.py @@ -3,8 +3,8 @@ import io import os +from collections.abc import MutableSequence from tempfile import NamedTemporaryFile -from typing import MutableSequence import pytest from unit.mocks import get_sample_conversations diff --git a/tests/unit/target/test_tts_target.py b/tests/unit/target/test_tts_target.py index 08c6f37526..a042d18310 100644 --- a/tests/unit/target/test_tts_target.py +++ b/tests/unit/target/test_tts_target.py @@ -3,7 +3,7 @@ import os import uuid -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -157,7 +157,7 @@ async def test_tts_send_prompt_async_exception_adds_to_memory( with patch.object(tts_target._async_client.audio.speech, "create", new_callable=AsyncMock) as mock_create: mock_create.side_effect = sdk_exception - with pytest.raises((exception_class)): + with pytest.raises(exception_class): await tts_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_video_target.py b/tests/unit/target/test_video_target.py index 38213c21d4..eab0d81ac4 100644 --- a/tests/unit/target/test_video_target.py +++ b/tests/unit/target/test_video_target.py @@ -3,7 +3,7 @@ import json import uuid -from typing import MutableSequence +from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest