diff --git a/build_scripts/conditional_jb_build.py b/build_scripts/conditional_jb_build.py index 90c4696aea..0258aa1f3b 100644 --- a/build_scripts/conditional_jb_build.py +++ b/build_scripts/conditional_jb_build.py @@ -27,9 +27,8 @@ def main(): cwd=os.path.dirname(os.path.dirname(__file__)), # Repository root ) return result.returncode - else: - print("RUN_LONG_PRECOMMIT not set: Skipping full Jupyter Book build (fast validation runs instead)") - return 0 + print("RUN_LONG_PRECOMMIT not set: Skipping full Jupyter Book build (fast validation runs instead)") + return 0 if __name__ == "__main__": diff --git a/build_scripts/prepare_package.py b/build_scripts/prepare_package.py index 1ed307d5c0..0eec9f3d69 100644 --- a/build_scripts/prepare_package.py +++ b/build_scripts/prepare_package.py @@ -110,9 +110,8 @@ def copy_frontend_to_package(frontend_dist: Path, backend_frontend: Path) -> boo if index_html.exists(): print("✓ Frontend successfully copied to package") return True - else: - print("ERROR: index.html not found after copy") - return False + print("ERROR: index.html not found after copy") + return False def main(): diff --git a/build_scripts/validate_jupyter_book.py b/build_scripts/validate_jupyter_book.py index afbdf40938..1b067207ef 100644 --- a/build_scripts/validate_jupyter_book.py +++ b/build_scripts/validate_jupyter_book.py @@ -327,9 +327,8 @@ def main(): print(f" • {error}", file=sys.stderr) print("=" * 80, file=sys.stderr) return 1 - else: - print("\n[OK] All Jupyter Book validations passed!") - return 0 + print("\n[OK] All Jupyter Book validations passed!") + return 0 if __name__ == "__main__": diff --git a/doc/deployment/deploy_hf_model_aml.ipynb b/doc/deployment/deploy_hf_model_aml.ipynb index 58f7e22bb9..08db471277 100644 --- a/doc/deployment/deploy_hf_model_aml.ipynb +++ b/doc/deployment/deploy_hf_model_aml.ipynb @@ -222,9 +222,7 @@ " # Generate a 5-char random alphanumeric string and append to '-'\n", " random_suffix = \"-\" + \"\".join(random.choices(string.ascii_letters + string.digits, k=5))\n", "\n", - " updated_name = f\"{base_name}{random_suffix}\"\n", - "\n", - " return updated_name" + " return f\"{base_name}{random_suffix}\"" ] }, { diff --git a/doc/deployment/deploy_hf_model_aml.py b/doc/deployment/deploy_hf_model_aml.py index 5607472b83..8e6029c234 100644 --- a/doc/deployment/deploy_hf_model_aml.py +++ b/doc/deployment/deploy_hf_model_aml.py @@ -185,9 +185,7 @@ def get_updated_endpoint_name(endpoint_name): # Generate a 5-char random alphanumeric string and append to '-' random_suffix = "-" + "".join(random.choices(string.ascii_letters + string.digits, k=5)) - updated_name = f"{base_name}{random_suffix}" - - return updated_name + return f"{base_name}{random_suffix}" # %% diff --git a/doc/deployment/download_and_register_hf_model_aml.ipynb b/doc/deployment/download_and_register_hf_model_aml.ipynb index 095c32a356..0d25adaa0b 100644 --- a/doc/deployment/download_and_register_hf_model_aml.ipynb +++ b/doc/deployment/download_and_register_hf_model_aml.ipynb @@ -340,8 +340,7 @@ "\n", " # Find the model with the maximum version number\n", " max_version = max(models, key=lambda x: int(x.version)).version\n", - " model_max_version = str(int(max_version))\n", - " return model_max_version" + " return str(int(max_version))" ] }, { diff --git a/doc/deployment/download_and_register_hf_model_aml.py b/doc/deployment/download_and_register_hf_model_aml.py index 70aa412bf3..e650a0568e 100644 --- a/doc/deployment/download_and_register_hf_model_aml.py +++ b/doc/deployment/download_and_register_hf_model_aml.py @@ -262,8 +262,7 @@ def get_max_model_version(models: list) -> str: # Find the model with the maximum version number max_version = max(models, key=lambda x: int(x.version)).version - model_max_version = str(int(max_version)) - return model_max_version + return str(int(max_version)) # %% diff --git a/frontend/dev.py b/frontend/dev.py index 71acafeb53..f699ccf84e 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -120,9 +120,7 @@ def start_backend(initializers: list[str] | None = None): cmd.extend(["--initializers"] + initializers) # Start backend - backend = subprocess.Popen(cmd, env=env) - - return backend + return subprocess.Popen(cmd, env=env) def start_frontend(): @@ -134,9 +132,7 @@ def start_frontend(): # Start frontend process npm_cmd = "npm.cmd" if is_windows() else "npm" - frontend = subprocess.Popen([npm_cmd, "run", "dev"]) - - return frontend + return subprocess.Popen([npm_cmd, "run", "dev"]) def start_servers(): @@ -199,7 +195,7 @@ def main(): if command == "stop": stop_servers() return - elif command == "restart": + if command == "restart": stop_servers() time.sleep(1) elif command == "start": diff --git a/pyproject.toml b/pyproject.toml index 5cf8c7f4c6..d0fa09f6bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -252,6 +252,7 @@ select = [ "DOC", # https://docs.astral.sh/ruff/rules/#pydoclint-doc "F401", # unused-import "I", # isort + "RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret "W", # https://docs.astral.sh/ruff/rules/#pycodestyle-w ] ignore = [ diff --git a/pyrit/analytics/text_matching.py b/pyrit/analytics/text_matching.py index 627f8ff312..620d3d50fc 100644 --- a/pyrit/analytics/text_matching.py +++ b/pyrit/analytics/text_matching.py @@ -69,8 +69,7 @@ def is_match(self, *, target: str, text: str) -> bool: text = text.strip() if self._case_sensitive: return target in text - else: - return target.lower() in text.lower() + return target.lower() in text.lower() class ApproximateTextMatching(TextMatching): @@ -141,9 +140,7 @@ def _calculate_ngram_overlap(self, *, target: str, text: str) -> float: matching_ngrams = sum(int(ngram in text_str) for ngram in target_ngrams) # Calculate proportion of matching n-grams - score = matching_ngrams / len(target_ngrams) - - return score + return matching_ngrams / len(target_ngrams) def get_overlap_score(self, *, target: str, text: str) -> float: """ diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 44b801600d..d902e419c4 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -221,8 +221,7 @@ def get_azure_token_provider(scope: str) -> Callable[[], str]: >>> token = token_provider() # Get current token """ try: - token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope) - return token_provider + return get_bearer_token_provider(DefaultAzureCredential(), scope) except Exception as e: logger.error(f"Failed to obtain token provider for '{scope}': {e}") raise @@ -246,8 +245,7 @@ def get_azure_async_token_provider(scope: str): # type: ignore[no-untyped-def] >>> token = await token_provider() # Get current token (in async context) """ try: - token_provider = get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope) - return token_provider + return get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope) except Exception as e: logger.error(f"Failed to obtain async token provider for '{scope}': {e}") raise @@ -334,13 +332,12 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi subscription=key, region=region, ) - elif resource_id and region: + if resource_id and region: return get_speech_config_from_default_azure_credential( resource_id=resource_id, region=region, ) - else: - raise ValueError("Insufficient information provided for Azure Speech service.") + raise ValueError("Insufficient information provided for Azure Speech service.") def get_speech_config_from_default_azure_credential(resource_id: str, region: str) -> speechsdk.SpeechConfig: @@ -370,11 +367,10 @@ def get_speech_config_from_default_azure_credential(resource_id: str, region: st azure_auth = AzureAuth(token_scope=get_default_azure_scope("")) token = azure_auth.get_token() authorization_token = "aad#" + resource_id + "#" + token - speech_config = speechsdk.SpeechConfig( + return speechsdk.SpeechConfig( auth_token=authorization_token, region=region, ) - return speech_config except Exception as e: logger.error(f"Failed to get speech config for resource ID '{resource_id}' and region '{region}': {e}") raise diff --git a/pyrit/auth/azure_storage_auth.py b/pyrit/auth/azure_storage_auth.py index 5f69e2d625..f24576aa62 100644 --- a/pyrit/auth/azure_storage_auth.py +++ b/pyrit/auth/azure_storage_auth.py @@ -34,12 +34,10 @@ async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> Use delegation_key_start_time = datetime.now() delegation_key_expiry_time = delegation_key_start_time + timedelta(days=1) - user_delegation_key = await blob_service_client.get_user_delegation_key( + return await blob_service_client.get_user_delegation_key( key_start_time=delegation_key_start_time, key_expiry_time=delegation_key_expiry_time ) - return user_delegation_key - @staticmethod async def get_sas_token(container_url: str) -> str: """ diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 8f80f89024..865768ff30 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -208,7 +208,7 @@ async def _get_cached_token_if_available_and_valid(self) -> Optional[dict[str, A if not cached_user: logger.info("No user associated with cached token. Token invalidated.") return None - elif cached_user != self._username: + if cached_user != self._username: logger.info( f"Cached token is for different user (cached: {cached_user}, current: {self._username}). " "Token invalidated." diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 6a19ababb2..02edf7f611 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -66,44 +66,41 @@ def default(self, obj: Any) -> Any: def get_embedding_layer(model: Any) -> Any: if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel): return model.transformer.wte - elif isinstance(model, LlamaForCausalLM): + if isinstance(model, LlamaForCausalLM): return model.model.embed_tokens - elif isinstance(model, GPTNeoXForCausalLM): + if isinstance(model, GPTNeoXForCausalLM): return model.base_model.embed_in - elif isinstance(model, Phi3ForCausalLM): + if isinstance(model, Phi3ForCausalLM): return model.model.embed_tokens - else: - raise ValueError(f"Unknown model type: {type(model)}") + raise ValueError(f"Unknown model type: {type(model)}") def get_embedding_matrix(model: Any) -> Any: if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel): return model.transformer.wte.weight - elif isinstance(model, LlamaForCausalLM): + if isinstance(model, LlamaForCausalLM): return model.model.embed_tokens.weight - elif isinstance(model, GPTNeoXForCausalLM): + if isinstance(model, GPTNeoXForCausalLM): return model.base_model.embed_in.weight # type: ignore[union-attr, unused-ignore] - elif isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM): + if isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM): return model.model.embed_tokens.weight - elif isinstance(model, Phi3ForCausalLM): + if isinstance(model, Phi3ForCausalLM): return model.model.embed_tokens.weight - else: - raise ValueError(f"Unknown model type: {type(model)}") + raise ValueError(f"Unknown model type: {type(model)}") def get_embeddings(model: Any, input_ids: torch.Tensor) -> Any: if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel): return model.transformer.wte(input_ids).half() - elif isinstance(model, LlamaForCausalLM): + if isinstance(model, LlamaForCausalLM): return model.model.embed_tokens(input_ids) - elif isinstance(model, GPTNeoXForCausalLM): + if isinstance(model, GPTNeoXForCausalLM): return model.base_model.embed_in(input_ids).half() # type: ignore[operator, unused-ignore] - elif isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM): + if isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM): return model.model.embed_tokens(input_ids) - elif isinstance(model, Phi3ForCausalLM): + if isinstance(model, Phi3ForCausalLM): return model.model.embed_tokens(input_ids) - else: - raise ValueError(f"Unknown model type: {type(model)}") + raise ValueError(f"Unknown model type: {type(model)}") def get_nonascii_toks(tokenizer: Any, device: str = "cpu") -> torch.Tensor: @@ -364,24 +361,21 @@ def logits(self, model: Any, test_controls: Any = None, return_ids: bool = False del locs, test_ids gc.collect() return model(input_ids=ids, attention_mask=attn_mask).logits, ids - else: - del locs, test_ids - logits = model(input_ids=ids, attention_mask=attn_mask).logits - del ids - gc.collect() - return logits + del locs, test_ids + logits = model(input_ids=ids, attention_mask=attn_mask).logits + del ids + gc.collect() + return logits def target_loss(self, logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: crit = nn.CrossEntropyLoss(reduction="none") loss_slice = slice(self._target_slice.start - 1, self._target_slice.stop - 1) - loss = crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._target_slice]) - return loss # type: ignore[no-any-return, unused-ignore] + return crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._target_slice]) # type: ignore[no-any-return] def control_loss(self, logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: crit = nn.CrossEntropyLoss(reduction="none") loss_slice = slice(self._control_slice.start - 1, self._control_slice.stop - 1) - loss = crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._control_slice]) - return loss # type: ignore[no-any-return, unused-ignore] + return crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._control_slice]) # type: ignore[no-any-return] @property def assistant_str(self) -> Any: @@ -527,8 +521,7 @@ def logits(self, model: Any, test_controls: Any = None, return_ids: bool = False vals = [prompt.logits(model, test_controls, return_ids) for prompt in self._prompts] if return_ids: return [val[0] for val in vals], [val[1] for val in vals] - else: - return vals + return vals def target_loss(self, logits: list[torch.Tensor], ids: list[torch.Tensor]) -> torch.Tensor: return torch.cat( diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 1016a65dc0..3fb5a8aa46 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -119,8 +119,7 @@ def sample_control( new_token_val = torch.gather( top_indices[new_token_pos], 1, torch.randint(0, topk, (batch_size, 1), device=grad.device) ) - new_control_toks = original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val) - return new_control_toks + return original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val) class GCGMultiPromptAttack(MultiPromptAttack): diff --git a/pyrit/auxiliary_attacks/gcg/experiments/train.py b/pyrit/auxiliary_attacks/gcg/experiments/train.py index 352b5d9d7d..1eac86dc56 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/train.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/train.py @@ -305,19 +305,18 @@ def _create_attack( mpa_batch_size=params.batch_size, mpa_n_steps=params.n_steps, ) - else: - return IndividualPromptAttack( - train_goals, - train_targets, - workers, - control_init=params.control_init, - logfile=f"{params.result_prefix}_{timestamp}.json", - managers=managers, - test_goals=getattr(params, "test_goals", []), - test_targets=getattr(params, "test_targets", []), - test_workers=test_workers, - mpa_deterministic=params.gbda_deterministic, - mpa_lr=params.learning_rate, - mpa_batch_size=params.batch_size, - mpa_n_steps=params.n_steps, - ) + return IndividualPromptAttack( + train_goals, + train_targets, + workers, + control_init=params.control_init, + logfile=f"{params.result_prefix}_{timestamp}.json", + managers=managers, + test_goals=getattr(params, "test_goals", []), + test_targets=getattr(params, "test_targets", []), + test_workers=test_workers, + mpa_deterministic=params.gbda_deterministic, + mpa_lr=params.learning_rate, + mpa_batch_size=params.batch_size, + mpa_n_steps=params.n_steps, + ) diff --git a/pyrit/common/path.py b/pyrit/common/path.py index fcdc4c92ec..f04718cf2a 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -15,8 +15,7 @@ def get_default_data_path(dir: str) -> pathlib.Path: """ if in_git_repo(): return pathlib.Path(PYRIT_PATH, "..", dir).resolve() - else: - return pathlib.Path(user_data_dir(dir, "pyrit")).resolve() + return pathlib.Path(user_data_dir(dir, "pyrit")).resolve() def in_git_repo() -> bool: diff --git a/pyrit/common/question_answer_helpers.py b/pyrit/common/question_answer_helpers.py index 350b623573..4ffa71411a 100644 --- a/pyrit/common/question_answer_helpers.py +++ b/pyrit/common/question_answer_helpers.py @@ -15,5 +15,4 @@ def construct_evaluation_prompt(entry: QuestionAnsweringEntry) -> str: str: The constructed prompt string. """ available_choices = "\n".join(f"index={c.index}, value={c.text}" for c in entry.choices) - prompt = f"Question:\n{entry.question}\n\nChoices:\n{available_choices}" - return prompt + return f"Question:\n{entry.question}\n\nChoices:\n{available_choices}" diff --git a/pyrit/common/utils.py b/pyrit/common/utils.py index fe5bffdfd9..401361eb79 100644 --- a/pyrit/common/utils.py +++ b/pyrit/common/utils.py @@ -75,8 +75,7 @@ def combine_list(list1: Union[str, List[str]], list2: Union[str, List[str]]) -> list2 = [list2] # Merge and keep only unique values - combined = list(set(list1 + list2)) - return combined + return list(set(list1 + list2)) def get_random_indices(*, start: int, size: int, proportion: float) -> List[int]: diff --git a/pyrit/common/yaml_loadable.py b/pyrit/common/yaml_loadable.py index 44a9e3beac..7b03a7fc03 100644 --- a/pyrit/common/yaml_loadable.py +++ b/pyrit/common/yaml_loadable.py @@ -42,5 +42,4 @@ def from_yaml_file(cls: Type[T], file: Union[Path | str]) -> T: # otherwise, just instantiate directly with **yaml_data if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): return cls.from_dict(yaml_data) # type: ignore - else: - return cls(**yaml_data) + return cls(**yaml_data) diff --git a/pyrit/datasets/executors/question_answer/wmdp_dataset.py b/pyrit/datasets/executors/question_answer/wmdp_dataset.py index ed91730989..61487b5229 100644 --- a/pyrit/datasets/executors/question_answer/wmdp_dataset.py +++ b/pyrit/datasets/executors/question_answer/wmdp_dataset.py @@ -57,7 +57,7 @@ def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDatas ) questions_answers.append(entry) - dataset = QuestionAnsweringDataset( + return QuestionAnsweringDataset( name="wmdp", description="""The WMDP Benchmark: Measuring and Reducing Malicious Use With Unlearning. The Weapons of Mass Destruction Proxy (WMDP) benchmark is a dataset of 4,157 multiple-choice questions surrounding hazardous @@ -79,5 +79,3 @@ def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDatas source="https://huggingface.co/datasets/cais/wmdp", questions=questions_answers, ) - - return dataset diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index 7e13c697f3..5b61974ef5 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -210,13 +210,13 @@ def _map_jbb_category_to_harm_category(self, jbb_category: str) -> list[str]: # Special handling for common patterns if any(term in jbb_category_lower for term in ["violent", "kill", "murder", "bomb"]): return ["violence"] - elif any(term in jbb_category_lower for term in ["hate", "racist", "sexist"]): + if any(term in jbb_category_lower for term in ["hate", "racist", "sexist"]): return ["hate", "discrimination"] - elif any(term in jbb_category_lower for term in ["sexual", "porn", "nsfw"]): + if any(term in jbb_category_lower for term in ["sexual", "porn", "nsfw"]): return ["sexual"] - elif any(term in jbb_category_lower for term in ["illegal", "crime", "criminal"]): + if any(term in jbb_category_lower for term in ["illegal", "crime", "criminal"]): return ["criminal_planning", "illegal_activity"] - elif any(term in jbb_category_lower for term in ["harm", "hurt", "damage"]): + if any(term in jbb_category_lower for term in ["harm", "hurt", "damage"]): return ["violence", "harm"] # Default: use the original JBB category diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 1c9d538824..216c6d08e7 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -130,16 +130,13 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> List[Dict[st if file_type in FILE_TYPE_HANDLERS: if file_type == "json": return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) - else: - return cast( - List[Dict[str, str]], - FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), - ) - else: - valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") - else: - raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") + return cast( + List[Dict[str, str]], + FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), + ) + valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") def _fetch_from_file(self, *, source: str, file_type: str) -> List[Dict[str, str]]: """ @@ -158,9 +155,8 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> List[Dict[str, str with open(source, "r", encoding="utf-8") as file: if file_type in FILE_TYPE_HANDLERS: return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) - else: - valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") def _fetch_from_url( self, @@ -270,7 +266,7 @@ def _load_dataset_sync() -> Any: cache_dir = str(DB_DATA_PATH / "huggingface") if cache else None # Explicitly set download_mode to reuse cached data and never re-download - dataset = load_dataset( + return load_dataset( dataset_name, config, split=split, @@ -279,12 +275,10 @@ def _load_dataset_sync() -> Any: token=token, **kwargs, ) - return dataset try: # Run the synchronous load_dataset in a thread pool to avoid blocking the event loop - dataset = await asyncio.to_thread(_load_dataset_sync) - return dataset + return await asyncio.to_thread(_load_dataset_sync) except Exception as e: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise diff --git a/pyrit/embedding/_text_embedding.py b/pyrit/embedding/_text_embedding.py index 55857c7259..71d06abc34 100644 --- a/pyrit/embedding/_text_embedding.py +++ b/pyrit/embedding/_text_embedding.py @@ -41,7 +41,7 @@ def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse The embedding response """ embedding_obj = self._client.embeddings.create(input=text, model=self._model, **kwargs) - embedding_response = EmbeddingResponse( + return EmbeddingResponse( model=embedding_obj.model, object=embedding_obj.object, data=[ @@ -56,4 +56,3 @@ def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse total_tokens=embedding_obj.usage.total_tokens, ), ) - return embedding_response diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index 157e5676f9..5640a6a27b 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -82,7 +82,7 @@ async def generate_text_embedding_async(self, text: str, **kwargs: Any) -> Embed The embedding response """ embedding_obj = await self._async_client.embeddings.create(input=text, model=self._model, **kwargs) - embedding_response = EmbeddingResponse( + return EmbeddingResponse( model=embedding_obj.model, object=embedding_obj.object, data=[ @@ -97,7 +97,6 @@ async def generate_text_embedding_async(self, text: str, **kwargs: Any) -> Embed total_tokens=embedding_obj.usage.total_tokens, ), ) - return embedding_response def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse: """ diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 51d9f2f80b..873a2fa7f1 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -275,7 +275,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac # Determine the outcome outcome, outcome_reason = self._determine_attack_outcome(response=response, score=score, context=context) - result = AttackResult( + return AttackResult( conversation_id=context.session.conversation_id, objective=context.objective, attack_identifier=self.get_identifier(), @@ -287,8 +287,6 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac executed_turns=context.executed_turns, ) - return result - def _determine_attack_outcome( self, *, diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 33b2c75d75..132abc0a4a 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -445,7 +445,7 @@ def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[A prompt_text += f"\n\n{context.last_score.score_rationale}" return prompt_text - elif response_piece.is_blocked(): + if response_piece.is_blocked(): return RedTeamingAttack.DEFAULT_ADVERSARIAL_PROMPT_IF_OBJECTIVE_TARGET_IS_BLOCKED return f"Request to target failed: {response_piece.response_error}" diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index e62a80cabc..5d3ca0728d 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -337,12 +337,11 @@ def _get_audio_mime_type(self, *, audio_path: str) -> str: """ if audio_path.lower().endswith(".wav"): return "audio/wav" - elif audio_path.lower().endswith(".ogg"): + if audio_path.lower().endswith(".ogg"): return "audio/ogg" - elif audio_path.lower().endswith(".m4a"): + if audio_path.lower().endswith(".m4a"): return "audio/mp4" - else: - return "audio/mpeg" # Default fallback for .mp3, .mpeg, and unknown formats + return "audio/mpeg" # Default fallback for .mp3, .mpeg, and unknown formats def _format_image_content(self, *, image_path: str) -> List[str]: """ @@ -436,14 +435,12 @@ async def _format_piece_content_async(self, *, piece: MessagePiece, show_origina """ if piece.converted_value_data_type == "image_path": return self._format_image_content(image_path=piece.converted_value) - elif piece.converted_value_data_type == "audio_path": + if piece.converted_value_data_type == "audio_path": return self._format_audio_content(audio_path=piece.converted_value) - else: - # Handle text content (including errors) - if piece.has_error(): - return self._format_error_content(piece=piece) - else: - return self._format_text_content(piece=piece, show_original=show_original) + # Handle text content (including errors) + if piece.has_error(): + return self._format_error_content(piece=piece) + return self._format_text_content(piece=piece, show_original=show_original) def _format_message_scores(self, message: Message) -> List[str]: """ diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 3b3f623506..d18dba6013 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -227,7 +227,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta # Determine the outcome outcome, outcome_reason = self._determine_attack_outcome(response=response, score=score, context=context) - result = AttackResult( + return AttackResult( conversation_id=context.conversation_id, objective=context.objective, attack_identifier=self.get_identifier(), @@ -239,8 +239,6 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta executed_turns=1, ) - return result - def _determine_attack_outcome( self, *, response: Optional[Message], score: Optional[Score], context: SingleTurnAttackContext[Any] ) -> tuple[AttackOutcome, Optional[str]]: diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 3dc6005b6e..341caa7910 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -217,15 +217,13 @@ async def _run_experiment(self, context: FairnessBiasBenchmarkContext) -> Attack if not context.generated_message: raise ValueError("Message must be generated before running experiment") - attack_result = await self._prompt_sending_attack.execute_async( + return await self._prompt_sending_attack.execute_async( objective=context.generated_objective, next_message=context.generated_message, prepended_conversation=context.prepended_conversation, memory_labels=context.memory_labels, ) - return attack_result - def _format_experiment_results( self, context: FairnessBiasBenchmarkContext, attack_result: AttackResult, experiment_num: int ) -> Dict[str, Any]: @@ -243,7 +241,7 @@ def _format_experiment_results( conversation_pieces = self.memory.get_conversation(conversation_id=attack_result.conversation_id) response = conversation_pieces[1].get_value() if len(conversation_pieces) >= 2 else "" subject_name = self._extract_name(response) - experiment_result = { + return { "experiment_number": str(experiment_num + 1), "subject": context.subject, "story_type": context.story_type, @@ -254,7 +252,6 @@ def _format_experiment_results( attack_result.last_score.score_rationale if attack_result.last_score else "Scoring not available" ), } - return experiment_result def _format_story_prompt(self, *, subject: str, story_type: str) -> str: """ diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 889be02265..965e852d33 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -274,9 +274,8 @@ async def _prepare_examples_async(self, *, context: AnecdoctorContext) -> str: if self._processing_model: # Extract knowledge graph from examples using the processing model return await self._extract_knowledge_graph_async(context=context) - else: - # Use few-shot examples directly without knowledge graph extraction - return self._format_few_shot_examples(evaluation_data=context.evaluation_data) + # Use few-shot examples directly without knowledge graph extraction + return self._format_few_shot_examples(evaluation_data=context.evaluation_data) async def _send_examples_to_target_async( self, *, formatted_examples: str, context: AnecdoctorContext diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index aacfe84737..6617140e4f 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -998,7 +998,7 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts """ requests = self._create_normalizer_requests(prompts) - responses = await self._prompt_normalizer.send_prompt_batch_to_target_async( + return await self._prompt_normalizer.send_prompt_batch_to_target_async( requests=requests, target=self._objective_target, labels=context.memory_labels, @@ -1006,8 +1006,6 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts batch_size=self._batch_size, ) - return responses - def _create_normalizer_requests(self, prompts: List[str]) -> List[NormalizerRequest]: """ Create normalizer requests from prompts. @@ -1048,12 +1046,10 @@ async def _score_responses_async(self, *, responses: List[Message], tasks: List[ response_pieces = [response.message_pieces[0] for response in responses] # Score with objective scorer - scores = await self._scorer.score_prompts_batch_async( + return await self._scorer.score_prompts_batch_async( messages=[piece.to_message() for piece in response_pieces], objectives=tasks ) - return scores - def _process_scoring_results( self, *, @@ -1165,12 +1161,11 @@ def _normalize_score_to_float(self, score_value: Any) -> float: """ if isinstance(score_value, bool): return 1.0 if score_value else 0.0 - elif isinstance(score_value, (int, float)): + if isinstance(score_value, (int, float)): # Ensure value is between 0 and 1 return max(0.0, min(1.0, float(score_value))) - else: - self._logger.warning(f"Unexpected score type: {type(score_value)}, treating as 0.0") - return 0.0 + self._logger.warning(f"Unexpected score type: {type(score_value)}, treating as 0.0") + return 0.0 def _create_generation_result(self, context: FuzzerContext) -> FuzzerResult: """ @@ -1183,15 +1178,13 @@ def _create_generation_result(self, context: FuzzerContext) -> FuzzerResult: FuzzerResult: The generation result. """ # Create result with concrete fields - result = FuzzerResult( + return FuzzerResult( successful_templates=[node.template for node in context.new_prompt_nodes], jailbreak_conversation_ids=context.jailbreak_conversation_ids, total_queries=context.total_target_query_count, templates_explored=len(context.new_prompt_nodes), ) - return result - @overload async def execute_async( self, diff --git a/pyrit/identifiers/class_name_utils.py b/pyrit/identifiers/class_name_utils.py index 5c01b4f34b..f1a4d715a6 100644 --- a/pyrit/identifiers/class_name_utils.py +++ b/pyrit/identifiers/class_name_utils.py @@ -28,8 +28,7 @@ def class_name_to_snake_case(class_name: str, *, suffix: str = "") -> str: # Handle transitions like "XMLParser" -> "XML_Parser" name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name) # Handle transitions like "getHTTPResponse" -> "get_HTTP_Response" - name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() - return name + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() def snake_case_to_class_name(snake_case_name: str, *, suffix: str = "") -> str: diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 52c8cbf6d8..e6d36120f2 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -343,7 +343,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories combined_conditions = " AND ".join(harm_conditions) - targeted_harm_categories_subquery = exists().where( + return exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.targeted_harm_categories.isnot(None), @@ -352,7 +352,6 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories text(f"ISJSON(targeted_harm_categories) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), ) ) - return targeted_harm_categories_subquery def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -376,14 +375,13 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: combined_conditions = " AND ".join(label_conditions) - labels_subquery = exists().where( + return exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), ) ) - return labels_subquery def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ diff --git a/pyrit/memory/central_memory.py b/pyrit/memory/central_memory.py index 240319415f..a933e73107 100644 --- a/pyrit/memory/central_memory.py +++ b/pyrit/memory/central_memory.py @@ -41,5 +41,4 @@ def get_memory_instance(cls) -> MemoryInterface: if cls._memory_instance: logger.info(f"Using existing memory instance: {type(cls._memory_instance).__name__}") return cls._memory_instance - else: - raise ValueError("Central memory instance has not been set. Use `set_memory_instance` to set it.") + raise ValueError("Central memory instance has not been set. Use `set_memory_instance` to set it.") diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index 7bcfc2e39d..5c85f30c74 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -46,12 +46,11 @@ def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> Embe """ if message_piece.converted_value_data_type == "text": embedding_response = self.embedding_model.generate_text_embedding(text=message_piece.converted_value) - embedding_data = EmbeddingDataEntry( + return EmbeddingDataEntry( embedding=embedding_response.data[0].embedding, embedding_type_name=self.embedding_model.__class__.__name__, id=message_piece.id, ) - return embedding_data raise ValueError("Only text data is supported for embedding.") diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 0fb3bb742e..5f127ee1d5 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -81,8 +81,7 @@ def load_dialect_impl(self, dialect: Any) -> Any: """ if dialect.name == "sqlite": return dialect.type_descriptor(CHAR(36)) - else: - return dialect.type_descriptor(Uuid()) + return dialect.type_descriptor(Uuid()) def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Optional[str]: """ diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index f437ec4978..9c881acf48 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -464,7 +464,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - targeted_harm_categories_subquery = exists().where( + return exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, # Exclude empty strings, None, and empty lists @@ -479,7 +479,6 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories ), ) ) - return targeted_harm_categories_subquery def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -493,7 +492,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - labels_subquery = exists().where( + return exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), @@ -502,7 +501,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ), ) ) - return labels_subquery def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 1f4e95d9cf..5e23e4844a 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -136,18 +136,17 @@ async def _piece_to_content_dict_async(self, piece: MessagePiece) -> dict[str, A if data_type == "text": return {"type": "text", "text": content} - elif data_type == "image_path": + if data_type == "image_path": # Convert local image to base64 data URL data_url = await convert_local_image_to_data_url(content) return {"type": "image_url", "image_url": {"url": data_url}} - elif data_type == "audio_path": + if data_type == "audio_path": # Convert local audio to base64 for input_audio format return await self._convert_audio_to_input_audio(content) - elif data_type == "url": + if data_type == "url": # Direct URL (typically for images) return {"type": "image_url", "image_url": {"url": content}} - else: - raise ValueError(f"Data type '{data_type}' is not yet supported for chat message content.") + raise ValueError(f"Data type '{data_type}' is not yet supported for chat message content.") async def _convert_audio_to_input_audio(self, audio_path: str) -> dict[str, Any]: """ diff --git a/pyrit/message_normalizer/conversation_context_normalizer.py b/pyrit/message_normalizer/conversation_context_normalizer.py index 24d6447743..8238c1e2be 100644 --- a/pyrit/message_normalizer/conversation_context_normalizer.py +++ b/pyrit/message_normalizer/conversation_context_normalizer.py @@ -81,8 +81,7 @@ def _format_piece_content(self, piece: MessagePiece) -> str: if piece.prompt_metadata and "context_description" in piece.prompt_metadata: description = piece.prompt_metadata["context_description"] return f"[{data_type.capitalize()} - {description}]" - else: - return f"[{data_type.capitalize()}]" + return f"[{data_type.capitalize()}]" # For text pieces, include both original and converted if different original = piece.original_value @@ -90,5 +89,4 @@ def _format_piece_content(self, piece: MessagePiece) -> str: if original != converted: return f"{converted} (original: {original})" - else: - return converted + return converted diff --git a/pyrit/message_normalizer/message_normalizer.py b/pyrit/message_normalizer/message_normalizer.py index 8a6b3900d7..0af9f26c14 100644 --- a/pyrit/message_normalizer/message_normalizer.py +++ b/pyrit/message_normalizer/message_normalizer.py @@ -105,15 +105,14 @@ async def apply_system_message_behavior(messages: List[Message], behavior: Syste """ if behavior == "keep": return messages - elif behavior == "squash": + if behavior == "squash": # Import here to avoid circular imports from pyrit.message_normalizer.generic_system_squash import ( GenericSystemSquashNormalizer, ) return await GenericSystemSquashNormalizer().normalize_async(messages) - elif behavior == "ignore": + if behavior == "ignore": return [msg for msg in messages if msg.role != "system"] - else: - # This should never happen due to Literal type, but handle it gracefully - raise ValueError(f"Unknown system message behavior: {behavior}") + # This should never happen due to Literal type, but handle it gracefully + raise ValueError(f"Unknown system message behavior: {behavior}") diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index b62e3b5234..69d1d43512 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -219,11 +219,10 @@ async def normalize_string_async(self, messages: List[Message]) -> str: # Convert ChatMessage objects to dicts messages_list = [msg.model_dump(exclude_none=True) for msg in chat_messages] - formatted_messages = str( + return str( self.tokenizer.apply_chat_template( messages_list, tokenize=False, add_generation_prompt=True, ) ) - return formatted_messages diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 4ace4e2bba..9b8d490e66 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -56,33 +56,30 @@ def data_serializer_factory( if value is not None: if data_type in ["text", "reasoning", "function_call", "tool_call", "function_call_output"]: return TextDataTypeSerializer(prompt_text=value, data_type=data_type) - elif data_type == "image_path": + if data_type == "image_path": return ImagePathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - elif data_type == "audio_path": + if data_type == "audio_path": return AudioPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - elif data_type == "video_path": + if data_type == "video_path": return VideoPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - elif data_type == "binary_path": + if data_type == "binary_path": return BinaryPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - elif data_type == "error": + if data_type == "error": return ErrorDataTypeSerializer(prompt_text=value) - elif data_type == "url": + if data_type == "url": return URLDataTypeSerializer(category=category, prompt_text=value, extension=extension) - else: - raise ValueError(f"Data type {data_type} not supported") - else: - if data_type == "image_path": - return ImagePathDataTypeSerializer(category=category, extension=extension) - elif data_type == "audio_path": - return AudioPathDataTypeSerializer(category=category, extension=extension) - elif data_type == "video_path": - return VideoPathDataTypeSerializer(category=category, extension=extension) - elif data_type == "binary_path": - return BinaryPathDataTypeSerializer(category=category, extension=extension) - elif data_type == "error": - return ErrorDataTypeSerializer(prompt_text="") - else: - raise ValueError(f"Data type {data_type} without prompt text not supported") + raise ValueError(f"Data type {data_type} not supported") + if data_type == "image_path": + return ImagePathDataTypeSerializer(category=category, extension=extension) + if data_type == "audio_path": + return AudioPathDataTypeSerializer(category=category, extension=extension) + if data_type == "video_path": + return VideoPathDataTypeSerializer(category=category, extension=extension) + if data_type == "binary_path": + return BinaryPathDataTypeSerializer(category=category, extension=extension) + if data_type == "error": + return ErrorDataTypeSerializer(prompt_text="") + raise ValueError(f"Data type {data_type} without prompt text not supported") class DataTypeSerializer(abc.ABC): diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index e1a6c01558..be1665a419 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -185,8 +185,7 @@ def normalize_scenario_name(scenario_name: str) -> str: # Convert snake_case to PascalCase # e.g., "content_harms" -> "ContentHarms" parts = scenario_name.split("_") - pascal_name = "".join(part.capitalize() for part in parts) - return pascal_name + return "".join(part.capitalize() for part in parts) # Already PascalCase or other format, return as-is return scenario_name diff --git a/pyrit/models/score.py b/pyrit/models/score.py index de754790b5..c8c04cce64 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -105,7 +105,7 @@ def get_value(self) -> bool | float: """ if self.score_type == "true_false": return self.score_value.lower() == "true" - elif self.score_type == "float_scale": + if self.score_type == "float_scale": return float(self.score_value) raise ValueError(f"Unknown scorer type: {self.score_type}") @@ -113,7 +113,7 @@ def get_value(self) -> bool | float: def validate(self, scorer_type: str, score_value: str) -> None: if scorer_type == "true_false" and str(score_value).lower() not in ["true", "false"]: raise ValueError(f"True False scorers must have a score value of 'true' or 'false' not {score_value}") - elif scorer_type == "float_scale": + if scorer_type == "float_scale": try: score = float(score_value) if not (0 <= score <= 1): diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 935438a627..4e94b55e1e 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -161,7 +161,7 @@ def _enforce_consistent_group_id(self) -> None: if len(existing_group_ids) > 1: raise ValueError("Inconsistent group IDs found across seeds.") - elif len(existing_group_ids) == 1: + if len(existing_group_ids) == 1: group_id = existing_group_ids.pop() for seed in self.seeds: seed.prompt_group_id = group_id @@ -308,8 +308,7 @@ def prepended_conversation(self) -> Optional[List[Message]]: return None return self._prompts_to_messages(prepended_prompts) - else: - return self._prompts_to_messages(list(self.prompts)) + return self._prompts_to_messages(list(self.prompts)) @property def next_message(self) -> Optional[Message]: diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 54deae56c1..ab19487b29 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -137,8 +137,7 @@ def _convert_to_path(self, path: Union[Path, str]) -> Path: """ Converts the path to a Path object if it's a string. """ - path = Path(path) if isinstance(path, str) else path - return path + return Path(path) if isinstance(path, str) else path class AzureBlobStorageIO(StorageIO): @@ -204,9 +203,8 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st + "enable delegation-based SAS authentication to connect to the storage account" ) raise - else: - logger.exception(msg=f"An unexpected error occurred: {exc}") - raise + logger.exception(msg=f"An unexpected error occurred: {exc}") + raise def parse_blob_url(self, file_path: str) -> tuple[str, str]: """Parses the blob URL to extract the container name and blob name.""" @@ -215,8 +213,7 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: container_name = parsed_url.path.split("/")[1] blob_name = "/".join(parsed_url.path.split("/")[2:]) return container_name, blob_name - else: - raise ValueError("Invalid blob URL") + raise ValueError("Invalid blob URL") async def read_file(self, path: Union[Path, str]) -> bytes: """ @@ -254,9 +251,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # Download the blob blob_stream = await blob_client.download_blob() - file_content = await blob_stream.readall() - - return file_content + return await blob_stream.readall() except Exception as exc: logger.exception(f"Failed to read file at {blob_name}: {exc}") diff --git a/pyrit/prompt_converter/bin_ascii_converter.py b/pyrit/prompt_converter/bin_ascii_converter.py index da48976520..2096c6f397 100644 --- a/pyrit/prompt_converter/bin_ascii_converter.py +++ b/pyrit/prompt_converter/bin_ascii_converter.py @@ -88,12 +88,11 @@ async def convert_word_async(self, word: str) -> str: """ if self._encoding_func == "hex": return word.encode("utf-8").hex().upper() - elif self._encoding_func == "quoted-printable": + if self._encoding_func == "quoted-printable": return binascii.b2a_qp(word.encode("utf-8")).decode("ascii") - elif self._encoding_func == "UUencode": + if self._encoding_func == "UUencode": return self._uuencode_chunk(word) - else: - raise ValueError(f"Unsupported encoding function: {self._encoding_func}") + raise ValueError(f"Unsupported encoding function: {self._encoding_func}") def _uuencode_chunk(self, text: str) -> str: """ @@ -128,10 +127,10 @@ def join_words(self, words: list[str]) -> str: if all_words_selected: if self._encoding_func == "hex": return "20".join(words) # 20 is the hex representation of space - elif self._encoding_func == "quoted-printable": + if self._encoding_func == "quoted-printable": # Quoted-printable uses =20 for space return "=20".join(words) - elif self._encoding_func == "UUencode": + if self._encoding_func == "UUencode": # UUencode: join with encoded space return "".join(words) # UUencode handles spaces within encoding return super().join_words(words=words) diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index 292908b2f6..a73ac431eb 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -213,16 +213,14 @@ def tree_to_json(node: Optional[TreeNode]) -> Optional[dict[str, Any]]: return json.dumps(tree_representation) def _encrypt_reverse(self, sentence: str) -> str: - reverse_sentence = " ".join(sentence.split(" ")[::-1]) - return reverse_sentence + return " ".join(sentence.split(" ")[::-1]) def _encrypt_odd_even(self, sentence: str) -> str: words = sentence.split() odd_words = words[::2] even_words = words[1::2] encrypted_words = odd_words + even_words - encrypted_sentence = " ".join(encrypted_words) - return encrypted_sentence + return " ".join(encrypted_words) def _encrypt_length(self, sentence: str) -> str: class WordData: diff --git a/pyrit/prompt_converter/human_in_the_loop_converter.py b/pyrit/prompt_converter/human_in_the_loop_converter.py index 7b98fa093e..f26adebb0f 100644 --- a/pyrit/prompt_converter/human_in_the_loop_converter.py +++ b/pyrit/prompt_converter/human_in_the_loop_converter.py @@ -97,10 +97,10 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text ).strip() if user_input == "1": return ConverterResult(output_text=prompt, output_type=input_type) - elif user_input == "2": + if user_input == "2": new_input = input("Enter new prompt to send: ") return await self.convert_async(prompt=new_input, input_type=input_type) - elif user_input == "3": + if user_input == "3": if self._converters: converters_str = str([converter.__class__.__name__ for converter in self._converters]) converter_index = -1 @@ -115,6 +115,5 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converter = self._converters[converter_index] new_result = await converter.convert_async(prompt=prompt, input_type=input_type) return await self.convert_async(prompt=new_result.output_text, input_type=new_result.output_type) - else: - raise ValueError("No converters were passed into the HumanInTheLoopConverter") + raise ValueError("No converters were passed into the HumanInTheLoopConverter") return ConverterResult(output_text=prompt, output_type=input_type) diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py index d01fee1dfd..54508d34d8 100644 --- a/pyrit/prompt_converter/insert_punctuation_converter.py +++ b/pyrit/prompt_converter/insert_punctuation_converter.py @@ -131,8 +131,7 @@ def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str: if self._between_words: return self._insert_between_words(words, word_indices, num_insertions, punctuation_list) - else: - return self._insert_within_words(prompt, num_insertions, punctuation_list) + return self._insert_within_words(prompt, num_insertions, punctuation_list) def _insert_between_words( self, words: List[str], word_indices: List[int], num_insertions: int, punctuation_list: List[str] diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index a10339e250..4cb5900fb7 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -204,9 +204,8 @@ def _prepare_content(self, prompt: str) -> str: if isinstance(prompt, str): logger.debug("No template provided. Using raw prompt.") return prompt - else: - logger.error("Prompt must be a string when no template is provided.") - raise ValueError("Prompt must be a string when no template is provided.") + logger.error("Prompt must be a string when no template is provided.") + raise ValueError("Prompt must be a string when no template is provided.") def _generate_pdf(self, content: str) -> bytes: """ diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 7cd37d2310..c776297179 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -162,8 +162,7 @@ async def convert_tokens_async( return ConverterResult(output_text=prompt, output_type="text") async def _replace_text_match(self, match: str) -> ConverterResult: - result = await self.convert_async(prompt=match, input_type="text") - return result + return await self.convert_async(prompt=match, input_type="text") def _build_identifier(self) -> ComponentIdentifier: """ diff --git a/pyrit/prompt_converter/random_capital_letters_converter.py b/pyrit/prompt_converter/random_capital_letters_converter.py index 9d01aa20e3..7abe92dbc2 100644 --- a/pyrit/prompt_converter/random_capital_letters_converter.py +++ b/pyrit/prompt_converter/random_capital_letters_converter.py @@ -78,9 +78,7 @@ def generate_random_positions(self, total_length: int, set_number: int) -> list[ ) # Generate a list of unique random positions - random_positions = random.sample(range(total_length), set_number) - - return random_positions + return random.sample(range(total_length), set_number) def string_to_upper_case_by_percentage(self, percentage: float, prompt: str) -> str: """ diff --git a/pyrit/prompt_converter/selective_text_converter.py b/pyrit/prompt_converter/selective_text_converter.py index 5f3fa551fc..a9066ede37 100644 --- a/pyrit/prompt_converter/selective_text_converter.py +++ b/pyrit/prompt_converter/selective_text_converter.py @@ -181,8 +181,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text if self._is_word_level: return await self._convert_word_level_async(prompt=prompt) - else: - return await self._convert_char_level_async(prompt=prompt) + return await self._convert_char_level_async(prompt=prompt) async def _convert_word_level_async(self, *, prompt: str) -> ConverterResult: """ diff --git a/pyrit/prompt_converter/text_selection_strategy.py b/pyrit/prompt_converter/text_selection_strategy.py index cb4583a59f..4afbb3ac6e 100644 --- a/pyrit/prompt_converter/text_selection_strategy.py +++ b/pyrit/prompt_converter/text_selection_strategy.py @@ -333,17 +333,17 @@ def select_range(self, *, text: str) -> tuple[int, int]: if self._anchor == "start": return (0, selection_len) - elif self._anchor == "end": + if self._anchor == "end": return (text_len - selection_len, text_len) - elif self._anchor == "middle": + if self._anchor == "middle": start = (text_len - selection_len) // 2 return (start, start + selection_len) - else: # random - if self._seed is not None: - random.seed(self._seed) - max_start = max(0, text_len - selection_len) - start = random.randint(0, max_start) if max_start > 0 else 0 - return (start, start + selection_len) + # random + if self._seed is not None: + random.seed(self._seed) + max_start = max(0, text_len - selection_len) + start = random.randint(0, max_start) if max_start > 0 else 0 + return (start, start + selection_len) class RangeSelectionStrategy(TextSelectionStrategy): @@ -465,9 +465,8 @@ def select_words(self, *, words: List[str]) -> List[int]: if self._case_sensitive: return [i for i, word in enumerate(words) if word in self._keywords] - else: - keywords_lower = [k.lower() for k in self._keywords] - return [i for i, word in enumerate(words) if word.lower() in keywords_lower] + keywords_lower = [k.lower() for k in self._keywords] + return [i for i, word in enumerate(words) if word.lower() in keywords_lower] class WordProportionSelectionStrategy(WordSelectionStrategy): diff --git a/pyrit/prompt_converter/token_smuggling/base.py b/pyrit/prompt_converter/token_smuggling/base.py index 217a42bf18..d05e72765a 100644 --- a/pyrit/prompt_converter/token_smuggling/base.py +++ b/pyrit/prompt_converter/token_smuggling/base.py @@ -70,9 +70,8 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text summary, encoded = self.encode_message(message=prompt) logger.info(f"Encoded message summary: {summary}") return ConverterResult(output_text=encoded, output_type="text") - else: - decoded = self.decode_message(message=prompt) - return ConverterResult(output_text=decoded, output_type="text") + decoded = self.decode_message(message=prompt) + return ConverterResult(output_text=decoded, output_type="text") def input_supported(self, input_type: PromptDataType) -> bool: """ diff --git a/pyrit/prompt_converter/unicode_confusable_converter.py b/pyrit/prompt_converter/unicode_confusable_converter.py index e86b8d676e..c623cc25bb 100644 --- a/pyrit/prompt_converter/unicode_confusable_converter.py +++ b/pyrit/prompt_converter/unicode_confusable_converter.py @@ -169,7 +169,6 @@ def _confusable(self, char: str) -> str: confusable_options = confusable_characters(char) if not confusable_options or char == " ": return char - elif self._deterministic or len(confusable_options) == 1: + if self._deterministic or len(confusable_options) == 1: return str(confusable_options[-1]) - else: - return str(random.choice(confusable_options)) + return str(random.choice(confusable_options)) diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index b13fe44a06..586913ecad 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -149,9 +149,8 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st + "enable delegation-based SAS authentication to connect to the storage account" ) raise - else: - logger.exception(msg=f"An unexpected error occurred: {exc}") - raise + logger.exception(msg=f"An unexpected error occurred: {exc}") + raise def _parse_url(self) -> tuple[str, str]: """ diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 686c60a399..24c49299dd 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -243,7 +243,7 @@ async def _construct_http_body_async( # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will # be ignored. We only include commonly supported parameters here - model-specific parameters like # stop sequences should be passed via **param_kwargs since different models use different EOS tokens. - data = { + return { "input_data": { "input_string": messages_dict, "parameters": { @@ -256,8 +256,6 @@ async def _construct_http_body_async( } } - return data - def _get_headers(self) -> dict[str, str]: """ Headers for accessing inference endpoint deployed in AML. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 53e505600a..50ca68a886 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -120,14 +120,13 @@ def with_client( Returns: HTTPTarget: an instance of HTTPTarget """ - instance = cls( + return cls( http_request=http_request, prompt_regex_string=prompt_regex_string, callback_function=callback_function, max_requests_per_minute=max_requests_per_minute, client=client, ) - return instance def _inject_prompt_into_request(self, request: MessagePiece) -> str: """ diff --git a/pyrit/prompt_target/http_target/http_target_callback_functions.py b/pyrit/prompt_target/http_target/http_target_callback_functions.py index 888ed100e8..636b1a0b9d 100644 --- a/pyrit/prompt_target/http_target/http_target_callback_functions.py +++ b/pyrit/prompt_target/http_target/http_target_callback_functions.py @@ -71,10 +71,8 @@ def parse_using_regex(response: requests.Response) -> str: if match: if url: return url + match.group() - else: - return match.group() - else: - return str(response.content) + return match.group() + return str(response.content) return parse_using_regex diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index c7a5638f08..6eff804f66 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -371,7 +371,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: logger.info("Tokenizer has a chat template. Applying it to the input messages.") # Apply the chat template to format and tokenize the messages - tokenized_chat = cast( + return cast( BatchEncoding, self.tokenizer.apply_chat_template( messages, @@ -381,14 +381,12 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: return_dict=True, ), ).to(self.device) - return tokenized_chat - else: - error_message = ( - "Tokenizer does not have a chat template. " - "This model is not supported, as we only support instruct models with a chat template." - ) - logger.error(error_message) - raise ValueError(error_message) + error_message = ( + "Tokenizer does not have a chat template. " + "This model is not supported, as we only support instruct models with a chat template." + ) + logger.error(error_message) + raise ValueError(error_message) def _validate_request(self, *, message: Message) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 54e451512a..f660939b5a 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -497,8 +497,7 @@ async def _build_chat_messages_async(self, conversation: MutableSequence[Message """ if self._is_text_message_format(conversation): return self._build_chat_messages_for_text(conversation) - else: - return await self._build_chat_messages_for_multi_modal_async(conversation) + return await self._build_chat_messages_for_multi_modal_async(conversation) def _is_text_message_format(self, conversation: MutableSequence[Message]) -> bool: """ diff --git a/pyrit/prompt_target/openai/openai_error_handling.py b/pyrit/prompt_target/openai/openai_error_handling.py index 09b6d424bd..f70372a150 100644 --- a/pyrit/prompt_target/openai/openai_error_handling.py +++ b/pyrit/prompt_target/openai/openai_error_handling.py @@ -79,10 +79,9 @@ def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: return True # Heuristic: Azure sometimes uses other codes with policy-related content return "content_filter" in json.dumps(data).lower() - else: - # String-based heuristic search - lower = str(data).lower() - return "content_filter" in lower or "policy_violation" in lower or "moderation_blocked" in lower + # String-based heuristic search + lower = str(data).lower() + return "content_filter" in lower or "policy_violation" in lower or "moderation_blocked" in lower def _extract_error_payload(exc: Exception) -> Tuple[Union[dict[str, object], str], bool]: @@ -129,7 +128,7 @@ def _extract_error_payload(exc: Exception) -> Tuple[Union[dict[str, object], str if body is not None: if isinstance(body, dict): return body, _is_content_filter_error(body) - elif isinstance(body, str): + if isinstance(body, str): try: data = json.loads(body) return data, _is_content_filter_error(data) diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 89324d3d4e..d0caa44e1c 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -176,11 +176,10 @@ async def _send_generate_request_async(self, message: Message) -> Message: image_generation_args["style"] = self.style # Use unified error handler for consistent error handling - response = await self._handle_openai_request( + return await self._handle_openai_request( api_call=lambda: self._async_client.images.generate(**image_generation_args), request=message, ) - return response async def _send_edit_request_async(self, message: Message) -> Message: """ @@ -227,13 +226,11 @@ async def _send_edit_request_async(self, message: Message) -> Message: if self.style: image_edit_args["style"] = self.style - response = await self._handle_openai_request( + return await self._handle_openai_request( api_call=lambda: self._async_client.images.edit(**image_edit_args), request=message, ) - return response - async def _construct_message_from_response(self, response: Any, request: Any) -> Message: """ Construct a Message from an ImagesResponse. diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 143e36d81c..8c4b98d7e7 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -507,13 +507,13 @@ async def receive_events(self, conversation_id: str) -> RealtimeTargetResult: logger.debug("Received response.done - finishing normally") break - elif event_type == "error": + if event_type == "error": error_message = event.error.message if hasattr(event.error, "message") else str(event.error) error_type = event.error.type if hasattr(event.error, "type") else "unknown" logger.error(f"Received 'error' event: [{error_type}] {error_message}") raise RuntimeError(f"Server error: [{error_type}] {error_message}") - elif event_type in ["response.audio.delta", "response.output_audio.delta"]: + if event_type in ["response.audio.delta", "response.output_audio.delta"]: audio_data = base64.b64decode(event.delta) result.audio_bytes += audio_data logger.debug(f"Decoded {len(audio_data)} bytes of audio data") diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 1db5661991..96bc9e997b 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -505,11 +505,10 @@ def model_dump_json(self) -> str: retry_after = _extract_retry_after_from_exception(e) logger.warning(f"429 via APIStatusError request_id={request_id} retry_after={retry_after}") raise RateLimitException() - else: - logger.exception( - f"APIStatusError request_id={request_id} status={getattr(e, 'status_code', None)} error={e}" - ) - raise + logger.exception( + f"APIStatusError request_id={request_id} status={getattr(e, 'status_code', None)} error={e}" + ) + raise except (APITimeoutError, APIConnectionError) as e: # Transient infrastructure errors - these are retryable request_id = _extract_request_id_from_exception(e) diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 82078c3f1d..3e915a37d5 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -379,7 +379,7 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> # Save the video to storage (include video.id for chaining remixes) return await self._save_video_response(request=request, video_data=video_content, video_id=video.id) - elif video.status == "failed": + if video.status == "failed": # Handle failed video generation (non-content-filter) error_message = str(video.error) if video.error else "Video generation failed" logger.error(f"Video generation failed: {error_message}") @@ -391,16 +391,15 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> response_type="error", error="processing", ) - else: - # Unexpected status - error_message = f"Video generation ended with unexpected status: {video.status}" - logger.error(error_message) - return construct_response_from_request( - request=request, - response_text_pieces=[error_message], - response_type="error", - error="unknown", - ) + # Unexpected status + error_message = f"Video generation ended with unexpected status: {video.status}" + logger.error(error_message) + return construct_response_from_request( + request=request, + response_text_pieces=[error_message], + response_type="error", + error="unknown", + ) async def _save_video_response( self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None @@ -427,15 +426,13 @@ async def _save_video_response( prompt_metadata: Optional[dict[str, Union[str, int]]] = {"video_id": video_id} if video_id else None # Construct response - response_entry = construct_response_from_request( + return construct_response_from_request( request=request, response_text_pieces=[video_path], response_type="video_path", prompt_metadata=prompt_metadata, ) - return response_entry - def _validate_request(self, *, message: Message) -> None: """ Validate the request message. diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 98e2f1e97d..73fe98bc44 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -159,18 +159,18 @@ def _get_selectors(self) -> CopilotSelectors: plus_button_dropdown_selector='button[aria-label="Open"]', file_picker_selector='button[aria-label="Add images or files"]', ) - else: # M365 Copilot - return CopilotSelectors( - input_selector='span[role="textbox"][contenteditable="true"][aria-label="Message Copilot"]', - send_button_selector='button[type="submit"]', - ai_messages_selector='div[data-testid="copilot-message-div"]', - ai_messages_group_selector=( - 'div[data-testid="copilot-message-div"] > div > div > div > div > div > div > div > div > div > div' - ), - text_content_selector="div > p", - plus_button_dropdown_selector='button[aria-label="Add content"]', - file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', - ) + # M365 Copilot + return CopilotSelectors( + input_selector='span[role="textbox"][contenteditable="true"][aria-label="Message Copilot"]', + send_button_selector='button[type="submit"]', + ai_messages_selector='div[data-testid="copilot-message-div"]', + ai_messages_group_selector=( + 'div[data-testid="copilot-message-div"] > div > div > div > div > div > div > div > div > div > div' + ), + text_content_selector="div > p", + plus_button_dropdown_selector='button[aria-label="Add content"]', + file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', + ) async def send_prompt_async(self, *, message: Message) -> list[Message]: """ @@ -336,9 +336,8 @@ async def _extract_content_if_ready_async( if content_ready: logger.debug("Content is ready!") return test_content - else: - logger.debug("Message exists but content not ready yet, continuing to wait...") - return None + logger.debug("Message exists but content not ready yet, continuing to wait...") + return None except Exception as e: # Continue waiting if extraction fails logger.debug(f"Error checking content readiness: {e}") @@ -713,12 +712,11 @@ def _assemble_response( # Single text response - maintain backward compatibility logger.debug(f"Returning single text response: '{response_pieces[0][0]}'") return response_pieces[0][0] - elif response_pieces: + if response_pieces: # Multimodal or multiple pieces logger.debug(f"Returning {len(response_pieces)} response pieces") return response_pieces - else: - return "" + return "" async def _extract_multimodal_content_async( self, selectors: CopilotSelectors, initial_group_count: int = 0 @@ -760,8 +758,7 @@ async def _extract_multimodal_content_async( # Return appropriate format, with fallback if needed if response_pieces: return self._assemble_response(response_pieces=response_pieces) - else: - return await self._extract_fallback_text_async(ai_message_groups=ai_message_groups) + return await self._extract_fallback_text_async(ai_message_groups=ai_message_groups) async def _send_text_async(self, *, text: str, input_selector: str) -> None: """ diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index f481ac31b0..0571cf216b 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -81,8 +81,7 @@ def create_instance(self, **kwargs: object) -> T: if self.factory is not None: return self.factory(**merged_kwargs) - else: - return self.registered_class(**merged_kwargs) + return self.registered_class(**merged_kwargs) class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 17e5d8b680..9139e9b3c3 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -260,9 +260,8 @@ async def initialize_async( # Validate that the stored scenario matches current configuration if self._validate_stored_scenario(stored_result=existing_result): return # Valid match - skip creating new scenario result - else: - # Validation failed - will create new scenario result - self._scenario_result_id = None + # Validation failed - will create new scenario result + self._scenario_result_id = None else: logger.warning( f"Scenario result ID {self._scenario_result_id} not found in memory. Creating new scenario result." @@ -564,8 +563,7 @@ async def run_async(self) -> ScenarioResult: last_exception = None for retry_attempt in range(self._max_retries + 1): # +1 for initial attempt try: - result = await self._execute_scenario_async() - return result + return await self._execute_scenario_async() except Exception as e: last_exception = e @@ -584,14 +582,13 @@ async def run_async(self) -> ScenarioResult: ) # Continue to next iteration for retry continue - else: - # No more retries, log final failure - logger.error( - f"Scenario '{self._name}' failed after {current_tries} attempts " - f"(initial + {self._max_retries} retries) with error: {str(e)}. Giving up.", - exc_info=True, - ) - raise + # No more retries, log final failure + logger.error( + f"Scenario '{self._name}' failed after {current_tries} attempts " + f"(initial + {self._max_retries} retries) with error: {str(e)}. Giving up.", + exc_info=True, + ) + raise # This should never be reached, but just in case if last_exception: @@ -645,8 +642,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if scenario_results: return scenario_results[0] - else: - raise ValueError(f"Scenario result with ID {scenario_result_id} not found") + raise ValueError(f"Scenario result with ID {scenario_result_id} not found") logger.info( f"Scenario '{self._name}' has {len(remaining_attacks)} atomic attacks " @@ -715,11 +711,10 @@ async def _execute_scenario_async(self) -> ScenarioResult: f"in scenario '{self._name}': {incomplete_count} of {incomplete_count + completed_count} " f"objectives incomplete. First failure: {atomic_results.incomplete_objectives[0][1]}" ) from atomic_results.incomplete_objectives[0][1] - else: - logger.info( - f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " - f"{len(atomic_results.completed_results)} results" - ) + logger.info( + f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " + f"{len(atomic_results.completed_results)} results" + ) except Exception as e: # Exception was raised either by run_async or by our check above diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index d1f1cdceb6..eaeea66f62 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -263,9 +263,7 @@ def prepare_scenario_strategies( ) # Normalize compositions (expands aggregates, validates compositions) - normalized = ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) - - return normalized + return ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) @classmethod def supports_composition(cls: type[T]) -> bool: diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 07f7b6e94b..c7d14e412a 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -196,9 +196,8 @@ def _get_rate_color(self, rate: int) -> str: """ if rate >= 75: return str(Fore.RED) # High success (bad for security) - elif rate >= 50: + if rate >= 50: return str(Fore.YELLOW) # Medium success - elif rate >= 25: + if rate >= 25: return str(Fore.CYAN) # Low success - else: - return str(Fore.GREEN) # Very low success (good for security) + return str(Fore.GREEN) # Very low success (good for security) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 869760607e..4e77342c2c 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -199,7 +199,7 @@ def _get_default_objective_scorer(self) -> TrueFalseScorer: Returns: TrueFalseScorer: A scorer that returns True when the model does NOT refuse. """ - refusal_scorer = TrueFalseInverterScorer( + return TrueFalseInverterScorer( scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT"), @@ -208,7 +208,6 @@ def _get_default_objective_scorer(self) -> TrueFalseScorer: ) ) ) - return refusal_scorer def _create_adversarial_target(self) -> OpenAIChatTarget: """ diff --git a/pyrit/scenario/scenarios/airt/psychosocial_scenario.py b/pyrit/scenario/scenarios/airt/psychosocial_scenario.py index 0928513f89..bf1f8f058e 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial_scenario.py +++ b/pyrit/scenario/scenarios/airt/psychosocial_scenario.py @@ -487,7 +487,7 @@ def _create_attacks_for_strategy( scoring_config=scoring_config, seed_groups=seed_groups, ) - elif strategy == "multi_turn": + if strategy == "multi_turn": return [ self._create_multi_turn_attack( scoring_config=scoring_config, @@ -495,8 +495,7 @@ def _create_attacks_for_strategy( seed_groups=seed_groups, ) ] - else: - raise ValueError(f"Unknown strategy: {strategy}") + raise ValueError(f"Unknown strategy: {strategy}") def _create_single_turn_attacks( self, diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index e9cd4e046a..1d4b774f29 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -339,5 +339,4 @@ async def _get_base64_image_data(self, message_piece: MessagePiece) -> str: image_serializer = data_serializer_factory( category="prompt-memory-entries", value=image_path, data_type="image_path", extension=ext ) - base64_encoded_data = await image_serializer.read_data_base64() - return base64_encoded_data + return await image_serializer.read_data_base64() diff --git a/pyrit/score/float_scale/plagiarism_scorer.py b/pyrit/score/float_scale/plagiarism_scorer.py index 71a2b4d9cc..444e15efb6 100644 --- a/pyrit/score/float_scale/plagiarism_scorer.py +++ b/pyrit/score/float_scale/plagiarism_scorer.py @@ -147,27 +147,23 @@ def _plagiarism_score( # Compute the LCS metric (normalized by reference length) if metric.value == "lcs": lcs_len = self._lcs_length(tokens_reference, tokens_response) - score = lcs_len / reference_len - return score + return lcs_len / reference_len # Compute the Levenshtein metric (normalized by max length) - elif metric.value == "levenshtein": + if metric.value == "levenshtein": lev_dist = self._levenshtein_distance(tokens_reference, tokens_response) max_len = max(reference_len, response_len) - score = 1 - (lev_dist / max_len) - return score + return 1 - (lev_dist / max_len) # Compute the Jaccard metric (normalized by number of n-grams in reference) - elif metric.value == "jaccard": + if metric.value == "jaccard": ref_ngrams = self._ngram_set(tokens_reference, n) if reference_len >= n else set() res_ngrams = self._ngram_set(tokens_response, n) if response_len >= n else set() if not ref_ngrams: return 0.0 - score = len(ref_ngrams & res_ngrams) / len(ref_ngrams) - return score + return len(ref_ngrams & res_ngrams) / len(ref_ngrams) - else: - raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'") + raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'") async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: """ diff --git a/pyrit/score/human/human_in_the_loop_gradio.py b/pyrit/score/human/human_in_the_loop_gradio.py index 9f01ec4824..3237ec7028 100644 --- a/pyrit/score/human/human_in_the_loop_gradio.py +++ b/pyrit/score/human/human_in_the_loop_gradio.py @@ -86,8 +86,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op asyncio.CancelledError: If the scoring operation is cancelled. """ try: - score = await asyncio.to_thread(self.retrieve_score, message_piece, objective=objective) - return score + return await asyncio.to_thread(self.retrieve_score, message_piece, objective=objective) except asyncio.CancelledError: self._rpc_server.stop() raise diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 35ef6cffbe..bce9566d77 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -78,16 +78,15 @@ def _get_quality_color( if higher_is_better: if value >= good_threshold: return Fore.GREEN # type: ignore[no-any-return] - elif value < bad_threshold: - return Fore.RED # type: ignore[no-any-return] - return Fore.CYAN # type: ignore[no-any-return] - else: - # Lower is better (e.g., MAE, score time) - if value <= good_threshold: - return Fore.GREEN # type: ignore[no-any-return] - elif value > bad_threshold: + if value < bad_threshold: return Fore.RED # type: ignore[no-any-return] return Fore.CYAN # type: ignore[no-any-return] + # Lower is better (e.g., MAE, score time) + if value <= good_threshold: + return Fore.GREEN # type: ignore[no-any-return] + if value > bad_threshold: + return Fore.RED # type: ignore[no-any-return] + return Fore.CYAN # type: ignore[no-any-return] def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ diff --git a/pyrit/score/score_utils.py b/pyrit/score/score_utils.py index f0465067d2..7a8258cac0 100644 --- a/pyrit/score/score_utils.py +++ b/pyrit/score/score_utils.py @@ -83,7 +83,6 @@ def normalize_score_to_float(score: Optional[Score]) -> float: score_value = score.get_value() if isinstance(score_value, bool): return 1.0 if score_value else 0.0 - elif isinstance(score_value, (int, float)): + if isinstance(score_value, (int, float)): return float(score_value) - else: - return 0.0 + return 0.0 diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 151cfd0992..f2ffa8e60c 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -87,10 +87,9 @@ def scorer_type(self) -> ScoreType: if isinstance(self, TrueFalseScorer): return "true_false" - elif isinstance(self, FloatScaleScorer): + if isinstance(self, FloatScaleScorer): return "float_scale" - else: - return "unknown" + return "unknown" @property def _memory(self) -> MemoryInterface: @@ -475,8 +474,7 @@ def scale_value_float(self, value: float, min_value: float, max_value: float) -> if max_value == min_value: return 0.0 - normalized_value = (value - min_value) / (max_value - min_value) - return normalized_value + return (value - min_value) / (max_value - min_value) @pyrit_json_retry async def _score_value_with_llm( @@ -669,7 +667,7 @@ def _extract_objective_from_response(self, response: Message) -> str: last_prompt = max(conversation, key=lambda x: x.sequence) # Every text message piece from the last turn - last_turn_text = "\n".join( + return "\n".join( [ piece.original_value for piece in conversation @@ -677,8 +675,6 @@ def _extract_objective_from_response(self, response: Message) -> str: ] ) - return last_turn_text - @staticmethod async def score_response_async( *, diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index 7c7a2e9b80..bee803e5bd 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -323,12 +323,11 @@ def _should_skip_evaluation( f"(requested {num_scorer_trials}). Skipping evaluation." ) return (True, existing) - else: - logger.info( - f"Existing metrics have fewer trials ({existing.num_scorer_trials} < {num_scorer_trials}). " - f"Will re-run evaluation with more trials and replace existing entry." - ) - return (False, None) + logger.info( + f"Existing metrics have fewer trials ({existing.num_scorer_trials} < {num_scorer_trials}). " + f"Will re-run evaluation with more trials and replace existing entry." + ) + return (False, None) except Exception as e: logger.warning(f"Error checking for existing metrics: {e}") diff --git a/pyrit/score/true_false/decoding_scorer.py b/pyrit/score/true_false/decoding_scorer.py index c702757cb0..2a4e765f5e 100644 --- a/pyrit/score/true_false/decoding_scorer.py +++ b/pyrit/score/true_false/decoding_scorer.py @@ -102,7 +102,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op match_found = True break - score = [ + return [ Score( score_value=str(match_found), score_value_description="", @@ -115,5 +115,3 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op objective=objective, ) ] - - return score diff --git a/pyrit/score/true_false/markdown_injection.py b/pyrit/score/true_false/markdown_injection.py index 3a677b98de..9964b10438 100644 --- a/pyrit/score/true_false/markdown_injection.py +++ b/pyrit/score/true_false/markdown_injection.py @@ -75,7 +75,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op pattern = r"!\[.*?\]\((.*?)\)|!\[.*?\]\[(.*?)\]" matches = re.findall(pattern, text) - scores = [ + return [ Score( score_value=str(True if matches else False), score_value_description="True if markdown injection is detected, else False.", @@ -88,5 +88,3 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op objective=objective, ) ] - - return scores diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index e2993b3628..5d6010376a 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -91,7 +91,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op matching_text = text break - scores = [ + return [ Score( score_value=str(result), score_value_description="", @@ -108,5 +108,3 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op objective=objective, ) ] - - return scores diff --git a/pyrit/score/true_false/substring_scorer.py b/pyrit/score/true_false/substring_scorer.py index 9dfbe4d502..38ab69a30a 100644 --- a/pyrit/score/true_false/substring_scorer.py +++ b/pyrit/score/true_false/substring_scorer.py @@ -81,7 +81,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op """ substring_present = self._text_matcher.is_match(target=self._substring, text=message_piece.converted_value) - score = [ + return [ Score( score_value=str(substring_present), score_value_description="", @@ -94,5 +94,3 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op objective=objective, ) ] - - return score diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 8df8e121b6..ec0bf7cdb3 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -219,7 +219,7 @@ def wait_for_score(self) -> Score: if score_ref is None: return None # Pass instance variables of reflected RPyC Score object as args to PyRIT Score object - score = Score( + return Score( score_value=score_ref.score_value, score_type=score_ref.score_type, score_category=score_ref.score_category, @@ -233,8 +233,6 @@ def wait_for_score(self) -> Score: ), ) - return score - def wait_for_client(self) -> None: """ Wait for the client to be ready to receive messages. diff --git a/tests/integration/targets/test_targets_and_secrets.py b/tests/integration/targets/test_targets_and_secrets.py index 3c2493667f..8084e4a931 100644 --- a/tests/integration/targets/test_targets_and_secrets.py +++ b/tests/integration/targets/test_targets_and_secrets.py @@ -55,8 +55,7 @@ async def _assert_can_send_prompt(target, check_if_llm_interpreted_request=True, def valid_response(resp: str) -> bool: if check_if_llm_interpreted_request: return "test" in resp.strip().lower() - else: - return True + return True attempt = 0 while attempt < max_retries: diff --git a/tests/unit/analytics/test_conversation_analytics.py b/tests/unit/analytics/test_conversation_analytics.py index 79a62def69..a2525c7be7 100644 --- a/tests/unit/analytics/test_conversation_analytics.py +++ b/tests/unit/analytics/test_conversation_analytics.py @@ -14,8 +14,7 @@ @pytest.fixture def mock_memory_interface(): - memory_interface = MagicMock(spec=MemoryInterface) - return memory_interface + return MagicMock(spec=MemoryInterface) @pytest.fixture diff --git a/tests/unit/converter/test_add_text_image_converter.py b/tests/unit/converter/test_add_text_image_converter.py index bd529aa72f..c76ed72535 100644 --- a/tests/unit/converter/test_add_text_image_converter.py +++ b/tests/unit/converter/test_add_text_image_converter.py @@ -16,8 +16,7 @@ def text_image_converter_sample_image_bytes(): img = Image.new("RGB", (100, 100), color=(125, 125, 125)) img_bytes = BytesIO() img.save(img_bytes, format="PNG") - img_bytes = img_bytes.getvalue() - return img_bytes + return img_bytes.getvalue() def test_add_text_image_converter_initialization(): diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 888fdb641f..ea2fb2be9c 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -64,7 +64,7 @@ class TestAttackContext(AttackContext): @pytest.fixture def sample_attack_result(): """Create a sample AttackResult for testing""" - result = AttackResult( + return AttackResult( conversation_id="test-conversation-id", objective="Test objective", outcome=AttackOutcome.SUCCESS, @@ -72,7 +72,6 @@ def sample_attack_result(): execution_time_ms=0, executed_turns=1, ) - return result @pytest.fixture @@ -106,7 +105,7 @@ async def _setup_async(self, *, context): pass async def _perform_async(self, *, context): - result = AttackResult( + return AttackResult( conversation_id="test-conversation-id", objective="Test objective", outcome=AttackOutcome.SUCCESS, @@ -114,7 +113,6 @@ async def _perform_async(self, *, context): execution_time_ms=0, executed_turns=1, ) - return result async def _teardown_async(self, *, context): pass @@ -487,14 +485,13 @@ async def _setup_async(self, *, context): pass async def _perform_async(self, *, context): - result = AttackResult( + return AttackResult( conversation_id="test-conversation-id", objective="Test objective", outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", executed_turns=1, ) - return result async def _teardown_async(self, *, context): pass @@ -531,7 +528,7 @@ async def _setup_async(self, *, context): pass async def _perform_async(self, *, context): - result = AttackResult( + return AttackResult( conversation_id="test-conversation-id", objective="Test objective", outcome=AttackOutcome.SUCCESS, @@ -539,7 +536,6 @@ async def _perform_async(self, *, context): execution_time_ms=0, executed_turns=1, ) - return result async def _teardown_async(self, *, context): pass diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 83d44f8672..cf1f1073ca 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -1132,7 +1132,7 @@ def node_components(self, attack_builder): prompt_normalizer = MagicMock() prompt_normalizer.send_prompt_async = AsyncMock(return_value=None) - components = { + return { "objective_target": builder.objective_target, "adversarial_chat": builder.adversarial_chat, "objective_scorer": builder.objective_scorer, @@ -1150,7 +1150,6 @@ def node_components(self, attack_builder): "parent_id": None, "prompt_normalizer": prompt_normalizer, } - return components def test_node_initialization(self, node_components): """Test _TreeOfAttacksNode initialization.""" @@ -1308,19 +1307,18 @@ async def normalizer_side_effect(*args, **kwargs): ) ] ) - else: - # Return normal response for objective target - return Message( - message_pieces=[ - MessagePiece( - role="assistant", - original_value="Target response", - converted_value="Target response", - conversation_id=node.objective_target_conversation_id, - id=str(uuid.uuid4()), - ) - ] - ) + # Return normal response for objective target + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value="Target response", + converted_value="Target response", + conversation_id=node.objective_target_conversation_id, + id=str(uuid.uuid4()), + ) + ] + ) mock_normalizer.send_prompt_async = AsyncMock(side_effect=normalizer_side_effect) diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py index 7b78eed2df..23a348ff40 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py @@ -72,8 +72,7 @@ def template_converters(self, scoring_target) -> list[FuzzerConverter]: @pytest.fixture def mock_scorer(self) -> MagicMock: """Mock scorer for testing.""" - scorer = MagicMock(TrueFalseScorer) - return scorer + return MagicMock(TrueFalseScorer) @pytest.fixture def fuzzer_context(self, simple_prompts: list[str], simple_prompt_templates: list[str]) -> FuzzerContext: diff --git a/tests/unit/models/test_embedding_response.py b/tests/unit/models/test_embedding_response.py index 24e737e696..fe39fa102e 100644 --- a/tests/unit/models/test_embedding_response.py +++ b/tests/unit/models/test_embedding_response.py @@ -11,24 +11,22 @@ @pytest.fixture def my_embedding() -> EmbeddingResponse: - embedding = EmbeddingResponse( + return EmbeddingResponse( model="test", object="test", usage=EmbeddingUsageInformation(prompt_tokens=0, total_tokens=0), data=[EmbeddingData(embedding=[0.0], index=0, object="embedding")], ) - return embedding @pytest.fixture def my_embedding_data() -> dict: - data = { + return { "model": "test", "object": "test", "usage": {"prompt_tokens": 0, "total_tokens": 0}, "data": [{"embedding": [0.0], "index": 0, "object": "embedding"}], } - return data def test_can_save_embeddings(my_embedding: EmbeddingResponse): diff --git a/tests/unit/scenarios/test_cyber.py b/tests/unit/scenarios/test_cyber.py index 88970baa9f..736ed03158 100644 --- a/tests/unit/scenarios/test_cyber.py +++ b/tests/unit/scenarios/test_cyber.py @@ -68,8 +68,7 @@ def slow_cyberstrategy(): def malware_prompts(): """The default malware prompts.""" malware_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" - seed_prompts = list(SeedDataset.from_yaml_file(malware_path / "malware.prompt").get_values()) - return seed_prompts + return list(SeedDataset.from_yaml_file(malware_path / "malware.prompt").get_values()) @pytest.fixture diff --git a/tests/unit/scenarios/test_leakage_scenario.py b/tests/unit/scenarios/test_leakage_scenario.py index 045f536a00..357ea296dd 100644 --- a/tests/unit/scenarios/test_leakage_scenario.py +++ b/tests/unit/scenarios/test_leakage_scenario.py @@ -78,8 +78,7 @@ def role_play_strategy(): def leakage_prompts(): """The default leakage prompts.""" leakage_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" - seed_prompts = list(SeedDataset.from_yaml_file(leakage_path / "leakage.prompt").get_values()) - return seed_prompts + return list(SeedDataset.from_yaml_file(leakage_path / "leakage.prompt").get_values()) @pytest.fixture diff --git a/tests/unit/scenarios/test_scenario_partial_results.py b/tests/unit/scenarios/test_scenario_partial_results.py index 9ac709eaca..4a2755da15 100644 --- a/tests/unit/scenarios/test_scenario_partial_results.py +++ b/tests/unit/scenarios/test_scenario_partial_results.py @@ -145,18 +145,17 @@ async def mock_run(*args, **kwargs): save_attack_results_to_memory(completed) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) - else: - # Retry: complete the remaining objective - completed = [ - AttackResult( - conversation_id="conv-3", - objective="obj3", - outcome=AttackOutcome.SUCCESS, - executed_turns=1, - ) - ] - save_attack_results_to_memory(completed) - return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) + # Retry: complete the remaining objective + completed = [ + AttackResult( + conversation_id="conv-3", + objective="obj3", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + ] + save_attack_results_to_memory(completed) + return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) atomic_attack.run_async = mock_run @@ -261,21 +260,20 @@ async def mock_run(*args, **kwargs): save_attack_results_to_memory(completed) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) - else: - # Retry: complete remaining objectives - completed = [ - AttackResult( - conversation_id=f"conv-{i}", - objective=f"obj{i}", - outcome=AttackOutcome.SUCCESS, - executed_turns=1, - ) - for i in [4, 5] - ] + # Retry: complete remaining objectives + completed = [ + AttackResult( + conversation_id=f"conv-{i}", + objective=f"obj{i}", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + for i in [4, 5] + ] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed) - return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) + return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) atomic_attack.run_async = mock_run @@ -335,25 +333,22 @@ async def mock_run(*args, **kwargs): save_attack_results_to_memory(completed) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) - else: - # All other attempts succeed fully - completed = [ - AttackResult( - conversation_id=f"conv-{obj}", - objective=obj, - outcome=AttackOutcome.SUCCESS, - executed_turns=1, - ) - for obj in ( - attack1 - if attack_name == "attack_1" - else (attack2 if attack_name == "attack_2" else attack3) - ).objectives - ] + # All other attempts succeed fully + completed = [ + AttackResult( + conversation_id=f"conv-{obj}", + objective=obj, + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + for obj in ( + attack1 if attack_name == "attack_1" else (attack2 if attack_name == "attack_2" else attack3) + ).objectives + ] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed) - return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) + return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) return mock_run diff --git a/tests/unit/scenarios/test_scenario_retry.py b/tests/unit/scenarios/test_scenario_retry.py index 72fb04852b..7f9d09ba37 100644 --- a/tests/unit/scenarios/test_scenario_retry.py +++ b/tests/unit/scenarios/test_scenario_retry.py @@ -415,12 +415,11 @@ async def mock_run_with_partial_completion(*args, **kwargs): results = [create_attack_result(i, objective=f"obj{i}") for i in [1, 2]] save_attack_results_to_memory(results) raise Exception("Failed after 2 objectives") - else: - # Retry: should only execute remaining objectives (obj3, obj4) - executed_objectives.extend(["obj3", "obj4"]) - results = [create_attack_result(i, objective=f"obj{i}") for i in [3, 4]] - save_attack_results_to_memory(results) - return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) + # Retry: should only execute remaining objectives (obj3, obj4) + executed_objectives.extend(["obj3", "obj4"]) + results = [create_attack_result(i, objective=f"obj{i}") for i in [3, 4]] + save_attack_results_to_memory(results) + return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) atomic_attack.run_async = mock_run_with_partial_completion @@ -466,18 +465,16 @@ async def mock_run_attack2(*args, **kwargs): results = [create_attack_result(2, objective="objective2")] save_attack_results_to_memory(results) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) - else: - raise AssertionError("Attack 2 should not be retried after completion") + raise AssertionError("Attack 2 should not be retried after completion") # Attack 3: Fails on first attempt, succeeds on retry async def mock_run_attack3(*args, **kwargs): call_count["attack_3"] += 1 if call_count["attack_3"] == 1: raise Exception("Attack 3 failed on first attempt") - else: - results = [create_attack_result(3, objective="objective3")] - save_attack_results_to_memory(results) - return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) + results = [create_attack_result(3, objective="objective3")] + save_attack_results_to_memory(results) + return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) attack1.run_async = mock_run_attack1 attack2.run_async = mock_run_attack2 diff --git a/tests/unit/score/test_scorer_evaluator.py b/tests/unit/score/test_scorer_evaluator.py index 35f6a47f71..46cccaa543 100644 --- a/tests/unit/score/test_scorer_evaluator.py +++ b/tests/unit/score/test_scorer_evaluator.py @@ -604,7 +604,7 @@ async def test_run_evaluation_async_combines_dataset_versions_with_duplicates( entry = HarmHumanLabeledEntry(responses, [0.2], "hate_speech") def make_dataset(version, harm_def_version): - dataset = HumanLabeledDataset( + return HumanLabeledDataset( name="test", metrics_type=MetricsType.HARM, entries=[entry], @@ -612,7 +612,6 @@ def make_dataset(version, harm_def_version): harm_definition="hate_speech.yaml", harm_definition_version=harm_def_version, ) - return dataset # All three files have dataset_version "1.0" - should concatenate to "1.0_1.0_1.0" # All have same harm_definition_version "1.0" - should stay as "1.0" (unique) diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index 68db0d8d05..a395b0ea08 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -118,7 +118,7 @@ async def test_initialize_async_deduplicates_datasets(self) -> None: def get_scenario_side_effect(name: str): if name == "scenario1": return mock_scenario1 - elif name == "scenario2": + if name == "scenario2": return mock_scenario2 return None @@ -158,7 +158,7 @@ async def test_initialize_async_handles_scenario_errors(self) -> None: def get_scenario_side_effect(name: str): if name == "good_scenario": return mock_scenario_good - elif name == "bad_scenario": + if name == "bad_scenario": return mock_scenario_bad return None diff --git a/tests/unit/target/test_azure_ml_chat_target.py b/tests/unit/target/test_azure_ml_chat_target.py index c17162ed35..abd29f9fb1 100644 --- a/tests/unit/target/test_azure_ml_chat_target.py +++ b/tests/unit/target/test_azure_ml_chat_target.py @@ -23,13 +23,12 @@ def sample_conversations() -> MutableSequence[MessagePiece]: @pytest.fixture def aml_online_chat(patch_central_database) -> AzureMLChatTarget: - aml_online_chat = AzureMLChatTarget( + return AzureMLChatTarget( endpoint="http://aml-test-endpoint.com", api_key="valid_api_key", extra_param1="sample", extra_param2=1.0, ) - return aml_online_chat def test_initialization_with_required_parameters( diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index e607de9e88..e75dfdc278 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -16,8 +16,7 @@ @pytest.fixture def mock_callback_function() -> Callable: - parsing_function = get_http_target_json_response_callback_function(key="mock_key") - return parsing_function + return get_http_target_json_response_callback_function(key="mock_key") @pytest.fixture diff --git a/tests/unit/target/test_http_target_parsing.py b/tests/unit/target/test_http_target_parsing.py index 17598d7546..5840c4e3ca 100644 --- a/tests/unit/target/test_http_target_parsing.py +++ b/tests/unit/target/test_http_target_parsing.py @@ -20,8 +20,7 @@ @pytest.fixture def mock_callback_function() -> Callable: - parsing_function = get_http_target_json_response_callback_function(key="mock_key") - return parsing_function + return get_http_target_json_response_callback_function(key="mock_key") @pytest.fixture diff --git a/tests/unit/target/test_openai_response_target_function_chaining.py b/tests/unit/target/test_openai_response_target_function_chaining.py index a9240bbd3c..7748e51e6f 100644 --- a/tests/unit/target/test_openai_response_target_function_chaining.py +++ b/tests/unit/target/test_openai_response_target_function_chaining.py @@ -21,12 +21,11 @@ @pytest.fixture def response_target(patch_central_database): """Create a test OpenAIResponseTarget.""" - target = OpenAIResponseTarget( + return OpenAIResponseTarget( model_name="gpt-4", endpoint="https://mock.azure.com", api_key="mock-key", ) - return target def create_mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock: diff --git a/tests/unit/target/test_prompt_shield_target.py b/tests/unit/target/test_prompt_shield_target.py index d29225161c..3ac3525f79 100644 --- a/tests/unit/target/test_prompt_shield_target.py +++ b/tests/unit/target/test_prompt_shield_target.py @@ -46,8 +46,7 @@ def sample_delineated_prompt_as_dict() -> dict: @pytest.fixture def sample_conversation_piece(sample_delineated_prompt_as_str: str) -> MessagePiece: - prp = MessagePiece(role="user", original_value=sample_delineated_prompt_as_str) - return prp + return MessagePiece(role="user", original_value=sample_delineated_prompt_as_str) def test_promptshield_init(promptshield_target: PromptShieldTarget):