Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions build_scripts/conditional_jb_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ def main():
cwd=os.path.dirname(os.path.dirname(__file__)), # Repository root
)
return result.returncode
else:
print("RUN_LONG_PRECOMMIT not set: Skipping full Jupyter Book build (fast validation runs instead)")
return 0
print("RUN_LONG_PRECOMMIT not set: Skipping full Jupyter Book build (fast validation runs instead)")
return 0


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions build_scripts/prepare_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ def copy_frontend_to_package(frontend_dist: Path, backend_frontend: Path) -> boo
if index_html.exists():
print("✓ Frontend successfully copied to package")
return True
else:
print("ERROR: index.html not found after copy")
return False
print("ERROR: index.html not found after copy")
return False


def main():
Expand Down
5 changes: 2 additions & 3 deletions build_scripts/validate_jupyter_book.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,8 @@ def main():
print(f" • {error}", file=sys.stderr)
print("=" * 80, file=sys.stderr)
return 1
else:
print("\n[OK] All Jupyter Book validations passed!")
return 0
print("\n[OK] All Jupyter Book validations passed!")
return 0


if __name__ == "__main__":
Expand Down
4 changes: 1 addition & 3 deletions doc/deployment/deploy_hf_model_aml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,7 @@
" # Generate a 5-char random alphanumeric string and append to '-'\n",
" random_suffix = \"-\" + \"\".join(random.choices(string.ascii_letters + string.digits, k=5))\n",
"\n",
" updated_name = f\"{base_name}{random_suffix}\"\n",
"\n",
" return updated_name"
" return f\"{base_name}{random_suffix}\""
]
},
{
Expand Down
4 changes: 1 addition & 3 deletions doc/deployment/deploy_hf_model_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ def get_updated_endpoint_name(endpoint_name):
# Generate a 5-char random alphanumeric string and append to '-'
random_suffix = "-" + "".join(random.choices(string.ascii_letters + string.digits, k=5))

updated_name = f"{base_name}{random_suffix}"

return updated_name
return f"{base_name}{random_suffix}"


# %%
Expand Down
3 changes: 1 addition & 2 deletions doc/deployment/download_and_register_hf_model_aml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@
"\n",
" # Find the model with the maximum version number\n",
" max_version = max(models, key=lambda x: int(x.version)).version\n",
" model_max_version = str(int(max_version))\n",
" return model_max_version"
" return str(int(max_version))"
]
},
{
Expand Down
3 changes: 1 addition & 2 deletions doc/deployment/download_and_register_hf_model_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ def get_max_model_version(models: list) -> str:

# Find the model with the maximum version number
max_version = max(models, key=lambda x: int(x.version)).version
model_max_version = str(int(max_version))
return model_max_version
return str(int(max_version))


# %%
Expand Down
10 changes: 3 additions & 7 deletions frontend/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def start_backend(initializers: list[str] | None = None):
cmd.extend(["--initializers"] + initializers)

# Start backend
backend = subprocess.Popen(cmd, env=env)

return backend
return subprocess.Popen(cmd, env=env)


def start_frontend():
Expand All @@ -134,9 +132,7 @@ def start_frontend():

# Start frontend process
npm_cmd = "npm.cmd" if is_windows() else "npm"
frontend = subprocess.Popen([npm_cmd, "run", "dev"])

return frontend
return subprocess.Popen([npm_cmd, "run", "dev"])


def start_servers():
Expand Down Expand Up @@ -199,7 +195,7 @@ def main():
if command == "stop":
stop_servers()
return
elif command == "restart":
if command == "restart":
stop_servers()
time.sleep(1)
elif command == "start":
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ select = [
"DOC", # https://docs.astral.sh/ruff/rules/#pydoclint-doc
"F401", # unused-import
"I", # isort
"RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret
"W", # https://docs.astral.sh/ruff/rules/#pycodestyle-w
]
ignore = [
Expand Down
7 changes: 2 additions & 5 deletions pyrit/analytics/text_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def is_match(self, *, target: str, text: str) -> bool:
text = text.strip()
if self._case_sensitive:
return target in text
else:
return target.lower() in text.lower()
return target.lower() in text.lower()


class ApproximateTextMatching(TextMatching):
Expand Down Expand Up @@ -141,9 +140,7 @@ def _calculate_ngram_overlap(self, *, target: str, text: str) -> float:
matching_ngrams = sum(int(ngram in text_str) for ngram in target_ngrams)

# Calculate proportion of matching n-grams
score = matching_ngrams / len(target_ngrams)

return score
return matching_ngrams / len(target_ngrams)

def get_overlap_score(self, *, target: str, text: str) -> float:
"""
Expand Down
14 changes: 5 additions & 9 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ def get_azure_token_provider(scope: str) -> Callable[[], str]:
>>> token = token_provider() # Get current token
"""
try:
token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope)
return token_provider
return get_bearer_token_provider(DefaultAzureCredential(), scope)
except Exception as e:
logger.error(f"Failed to obtain token provider for '{scope}': {e}")
raise
Expand All @@ -246,8 +245,7 @@ def get_azure_async_token_provider(scope: str): # type: ignore[no-untyped-def]
>>> token = await token_provider() # Get current token (in async context)
"""
try:
token_provider = get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope)
return token_provider
return get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope)
except Exception as e:
logger.error(f"Failed to obtain async token provider for '{scope}': {e}")
raise
Expand Down Expand Up @@ -334,13 +332,12 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi
subscription=key,
region=region,
)
elif resource_id and region:
if resource_id and region:
return get_speech_config_from_default_azure_credential(
resource_id=resource_id,
region=region,
)
else:
raise ValueError("Insufficient information provided for Azure Speech service.")
raise ValueError("Insufficient information provided for Azure Speech service.")


def get_speech_config_from_default_azure_credential(resource_id: str, region: str) -> speechsdk.SpeechConfig:
Expand Down Expand Up @@ -370,11 +367,10 @@ def get_speech_config_from_default_azure_credential(resource_id: str, region: st
azure_auth = AzureAuth(token_scope=get_default_azure_scope(""))
token = azure_auth.get_token()
authorization_token = "aad#" + resource_id + "#" + token
speech_config = speechsdk.SpeechConfig(
return speechsdk.SpeechConfig(
auth_token=authorization_token,
region=region,
)
return speech_config
except Exception as e:
logger.error(f"Failed to get speech config for resource ID '{resource_id}' and region '{region}': {e}")
raise
4 changes: 1 addition & 3 deletions pyrit/auth/azure_storage_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> Use
delegation_key_start_time = datetime.now()
delegation_key_expiry_time = delegation_key_start_time + timedelta(days=1)

user_delegation_key = await blob_service_client.get_user_delegation_key(
return await blob_service_client.get_user_delegation_key(
key_start_time=delegation_key_start_time, key_expiry_time=delegation_key_expiry_time
)

return user_delegation_key

@staticmethod
async def get_sas_token(container_url: str) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion pyrit/auth/copilot_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def _get_cached_token_if_available_and_valid(self) -> Optional[dict[str, A
if not cached_user:
logger.info("No user associated with cached token. Token invalidated.")
return None
elif cached_user != self._username:
if cached_user != self._username:
logger.info(
f"Cached token is for different user (cached: {cached_user}, current: {self._username}). "
"Token invalidated."
Expand Down
51 changes: 22 additions & 29 deletions pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,44 +66,41 @@ def default(self, obj: Any) -> Any:
def get_embedding_layer(model: Any) -> Any:
if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel):
return model.transformer.wte
elif isinstance(model, LlamaForCausalLM):
if isinstance(model, LlamaForCausalLM):
return model.model.embed_tokens
elif isinstance(model, GPTNeoXForCausalLM):
if isinstance(model, GPTNeoXForCausalLM):
return model.base_model.embed_in
elif isinstance(model, Phi3ForCausalLM):
if isinstance(model, Phi3ForCausalLM):
return model.model.embed_tokens
else:
raise ValueError(f"Unknown model type: {type(model)}")
raise ValueError(f"Unknown model type: {type(model)}")


def get_embedding_matrix(model: Any) -> Any:
if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel):
return model.transformer.wte.weight
elif isinstance(model, LlamaForCausalLM):
if isinstance(model, LlamaForCausalLM):
return model.model.embed_tokens.weight
elif isinstance(model, GPTNeoXForCausalLM):
if isinstance(model, GPTNeoXForCausalLM):
return model.base_model.embed_in.weight # type: ignore[union-attr, unused-ignore]
elif isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM):
if isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM):
return model.model.embed_tokens.weight
elif isinstance(model, Phi3ForCausalLM):
if isinstance(model, Phi3ForCausalLM):
return model.model.embed_tokens.weight
else:
raise ValueError(f"Unknown model type: {type(model)}")
raise ValueError(f"Unknown model type: {type(model)}")


def get_embeddings(model: Any, input_ids: torch.Tensor) -> Any:
if isinstance(model, GPTJForCausalLM) or isinstance(model, GPT2LMHeadModel):
return model.transformer.wte(input_ids).half()
elif isinstance(model, LlamaForCausalLM):
if isinstance(model, LlamaForCausalLM):
return model.model.embed_tokens(input_ids)
elif isinstance(model, GPTNeoXForCausalLM):
if isinstance(model, GPTNeoXForCausalLM):
return model.base_model.embed_in(input_ids).half() # type: ignore[operator, unused-ignore]
elif isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM):
if isinstance(model, MixtralForCausalLM) or isinstance(model, MistralForCausalLM):
return model.model.embed_tokens(input_ids)
elif isinstance(model, Phi3ForCausalLM):
if isinstance(model, Phi3ForCausalLM):
return model.model.embed_tokens(input_ids)
else:
raise ValueError(f"Unknown model type: {type(model)}")
raise ValueError(f"Unknown model type: {type(model)}")


def get_nonascii_toks(tokenizer: Any, device: str = "cpu") -> torch.Tensor:
Expand Down Expand Up @@ -364,24 +361,21 @@ def logits(self, model: Any, test_controls: Any = None, return_ids: bool = False
del locs, test_ids
gc.collect()
return model(input_ids=ids, attention_mask=attn_mask).logits, ids
else:
del locs, test_ids
logits = model(input_ids=ids, attention_mask=attn_mask).logits
del ids
gc.collect()
return logits
del locs, test_ids
logits = model(input_ids=ids, attention_mask=attn_mask).logits
del ids
gc.collect()
return logits

def target_loss(self, logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
crit = nn.CrossEntropyLoss(reduction="none")
loss_slice = slice(self._target_slice.start - 1, self._target_slice.stop - 1)
loss = crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._target_slice])
return loss # type: ignore[no-any-return, unused-ignore]
return crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._target_slice]) # type: ignore[no-any-return]

def control_loss(self, logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
crit = nn.CrossEntropyLoss(reduction="none")
loss_slice = slice(self._control_slice.start - 1, self._control_slice.stop - 1)
loss = crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._control_slice])
return loss # type: ignore[no-any-return, unused-ignore]
return crit(logits[:, loss_slice, :].transpose(1, 2), ids[:, self._control_slice]) # type: ignore[no-any-return]

@property
def assistant_str(self) -> Any:
Expand Down Expand Up @@ -527,8 +521,7 @@ def logits(self, model: Any, test_controls: Any = None, return_ids: bool = False
vals = [prompt.logits(model, test_controls, return_ids) for prompt in self._prompts]
if return_ids:
return [val[0] for val in vals], [val[1] for val in vals]
else:
return vals
return vals

def target_loss(self, logits: list[torch.Tensor], ids: list[torch.Tensor]) -> torch.Tensor:
return torch.cat(
Expand Down
3 changes: 1 addition & 2 deletions pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def sample_control(
new_token_val = torch.gather(
top_indices[new_token_pos], 1, torch.randint(0, topk, (batch_size, 1), device=grad.device)
)
new_control_toks = original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)
return new_control_toks
return original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)


class GCGMultiPromptAttack(MultiPromptAttack):
Expand Down
31 changes: 15 additions & 16 deletions pyrit/auxiliary_attacks/gcg/experiments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,18 @@ def _create_attack(
mpa_batch_size=params.batch_size,
mpa_n_steps=params.n_steps,
)
else:
return IndividualPromptAttack(
train_goals,
train_targets,
workers,
control_init=params.control_init,
logfile=f"{params.result_prefix}_{timestamp}.json",
managers=managers,
test_goals=getattr(params, "test_goals", []),
test_targets=getattr(params, "test_targets", []),
test_workers=test_workers,
mpa_deterministic=params.gbda_deterministic,
mpa_lr=params.learning_rate,
mpa_batch_size=params.batch_size,
mpa_n_steps=params.n_steps,
)
return IndividualPromptAttack(
train_goals,
train_targets,
workers,
control_init=params.control_init,
logfile=f"{params.result_prefix}_{timestamp}.json",
managers=managers,
test_goals=getattr(params, "test_goals", []),
test_targets=getattr(params, "test_targets", []),
test_workers=test_workers,
mpa_deterministic=params.gbda_deterministic,
mpa_lr=params.learning_rate,
mpa_batch_size=params.batch_size,
mpa_n_steps=params.n_steps,
)
3 changes: 1 addition & 2 deletions pyrit/common/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def get_default_data_path(dir: str) -> pathlib.Path:
"""
if in_git_repo():
return pathlib.Path(PYRIT_PATH, "..", dir).resolve()
else:
return pathlib.Path(user_data_dir(dir, "pyrit")).resolve()
return pathlib.Path(user_data_dir(dir, "pyrit")).resolve()


def in_git_repo() -> bool:
Expand Down
3 changes: 1 addition & 2 deletions pyrit/common/question_answer_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ def construct_evaluation_prompt(entry: QuestionAnsweringEntry) -> str:
str: The constructed prompt string.
"""
available_choices = "\n".join(f"index={c.index}, value={c.text}" for c in entry.choices)
prompt = f"Question:\n{entry.question}\n\nChoices:\n{available_choices}"
return prompt
return f"Question:\n{entry.question}\n\nChoices:\n{available_choices}"
3 changes: 1 addition & 2 deletions pyrit/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def combine_list(list1: Union[str, List[str]], list2: Union[str, List[str]]) ->
list2 = [list2]

# Merge and keep only unique values
combined = list(set(list1 + list2))
return combined
return list(set(list1 + list2))


def get_random_indices(*, start: int, size: int, proportion: float) -> List[int]:
Expand Down
3 changes: 1 addition & 2 deletions pyrit/common/yaml_loadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,4 @@ def from_yaml_file(cls: Type[T], file: Union[Path | str]) -> T:
# otherwise, just instantiate directly with **yaml_data
if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")):
return cls.from_dict(yaml_data) # type: ignore
else:
return cls(**yaml_data)
return cls(**yaml_data)
Loading