From a3e64f58aa86451ce27b1a5e72c10732c2e0b58a Mon Sep 17 00:00:00 2001 From: yanjunqiAz <99990718+yanjunqiAz@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:18:25 -0400 Subject: [PATCH 1/3] Add CLAUDE.md and IMPROVEMENT_PLAN.md project docs --- CLAUDE.md | 101 +++++++++++++++++++ IMPROVEMENT_PLAN.md | 240 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 341 insertions(+) create mode 100644 CLAUDE.md create mode 100644 IMPROVEMENT_PLAN.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..ead1ed3e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,101 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +TextAttack (v0.3.10) is a Python framework for adversarial attacks, data augmentation, and model training in NLP. It provides a modular system where attacks are composed of four pluggable components: goal functions, constraints, transformations, and search methods. The project is maintained by UVA QData Lab. + +## Common Commands + +### Installation (dev mode) +```bash +pip install -e .[dev] +``` + +### Testing +```bash +make test # Run full test suite (pytest --dist=loadfile -n auto) +pytest tests -v # Verbose test run +pytest tests/test_augment_api.py # Run a single test file +pytest --lf # Re-run only last failed tests +``` + +### Formatting & Linting +```bash +make format # Auto-format with black, isort, docformatter +make lint # Check formatting (black --check, isort --check-only, flake8) +``` + +### Building Docs +```bash +make docs # Build HTML docs with Sphinx +make docs-auto # Hot-reload docs server on port 8765 +``` + +### CLI Usage +```bash +textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 100 +textattack augment --input-csv examples.csv --output-csv output.csv --input-column text --recipe embedding +textattack train --model-name-or-path lstm --dataset yelp_polarity --epochs 50 +textattack list attack-recipes +textattack peek-dataset --dataset-from-huggingface snli +``` + +## Architecture + +### Core Attack Pipeline (`textattack/attack.py`, `textattack/attacker.py`) + +An `Attack` is composed of exactly four components: +1. **GoalFunction** (`textattack/goal_functions/`) - Determines if an attack succeeded. Categories: `classification/` (untargeted, targeted), `text/` (BLEU, translation overlap), `custom/`. +2. **Constraints** (`textattack/constraints/`) - Filter invalid perturbations. Categories: `semantics/` (sentence encoders, word embeddings), `grammaticality/` (POS, language models, grammar tools), `overlap/` (edit distance, BLEU), `pre_transformation/` (restrict search space before transforming). +3. **Transformation** (`textattack/transformations/`) - Generate candidate perturbations. Types: `word_swaps/` (embedding, gradient, homoglyph, WordNet), `word_insertions/`, `word_merges/`, `sentence_transformations/`, `WordDeletion`, `CompositeTransformation`. +4. **SearchMethod** (`textattack/search_methods/`) - Traverse the perturbation space. Includes: `BeamSearch`, `GreedySearch`, `GreedyWordSwapWIR`, `GeneticAlgorithm`, `ParticleSwarmOptimization`, `DifferentialEvolution`. + +The `Attacker` class orchestrates running attacks on datasets with parallel processing, checkpointing, and logging. + +### Attack Recipes (`textattack/attack_recipes/`) + +Pre-built attack configurations from the literature (e.g., TextFooler, DeepWordBug, BAE, BERT-Attack, CLARE, CheckList, etc.). Each recipe subclasses `AttackRecipe` and implements a `build(model_wrapper)` classmethod that returns a configured `Attack` object. Includes multi-lingual recipes for French, Spanish, and Chinese. + +### Key Abstractions + +- **`AttackedText`** (`textattack/shared/attacked_text.py`) - Central text representation that maintains both token list and original text with punctuation. Used throughout the pipeline instead of raw strings. +- **`ModelWrapper`** (`textattack/models/wrappers/`) - Abstract interface for models. Implementations for PyTorch, HuggingFace, TensorFlow, sklearn. Models must accept string input and return predictions. +- **`Dataset`** (`textattack/datasets/`) - Iterable of `(input, output)` pairs. Supports HuggingFace datasets and custom files. +- **`Augmenter`** (`textattack/augmentation/`) - Uses transformations and constraints for data augmentation (not adversarial attacks). Built-in recipes: wordnet, embedding, charswap, eda, checklist, clare, back_trans. +- **`PromptAugmentationPipeline`** (`textattack/prompt_augmentation/`) - Augments prompts and generates LLM responses. +- **LLM Wrappers** (`textattack/llms/`) - Wrappers for using LLMs (HuggingFace, ChatGPT) with prompt augmentation. + +### CLI Commands (`textattack/commands/`) + +Entry point: `textattack/commands/textattack_cli.py`. Each command (attack, augment, train, eval-model, list, peek-dataset, benchmark-recipe, attack-resume) is a subclass of `TextAttackCommand` with `register_subcommand()` and `run()` methods. + +### Configuration + +- Version tracked in `docs/conf.py` (imported by `setup.py`) +- Cache directory: `~/.cache/textattack/` (override with `TA_CACHE_DIR` env var) +- Formatting: black (line length 88), isort (skip `__init__.py`), flake8 (ignores: E203, E266, E501, W503, D203) + +### CI Workflows (`.github/workflows/`) + +- `check-formatting.yml` - Runs `make lint` on Python 3.9 +- `run-pytest.yml` - Sets up Python 3.8/3.9 (pytest currently skipped in CI) +- `publish-to-pypi.yml` - PyPI publishing +- `make-docs.yml` - Documentation build +- `codeql-analysis.yml` - Security analysis + +### Test Structure + +Tests are in `tests/` organized by feature: +- `test_command_line/` - CLI command integration tests (attack, augment, train, eval, list, loggers) +- `test_constraints/` - Constraint unit tests +- `test_augment_api.py`, `test_transformations.py`, `test_attacked_text.py`, `test_tokenizers.py`, `test_word_embedding.py`, `test_metric_api.py`, `test_prompt_augmentation.py` +- `test_command_line/update_test_outputs.py` - Script to regenerate expected test outputs + +### Adding New Components + +- **Attack recipe**: Subclass `AttackRecipe` in `textattack/attack_recipes/`, implement `build(model_wrapper)`, add import to `__init__.py`, add doc reference in `docs/attack_recipes.rst`. +- **Transformation**: Subclass `Transformation` in appropriate subfolder under `textattack/transformations/`. +- **Constraint**: Subclass `Constraint` or `PreTransformationConstraint` in appropriate subfolder under `textattack/constraints/`. +- **Search method**: Subclass `SearchMethod` in `textattack/search_methods/`. diff --git a/IMPROVEMENT_PLAN.md b/IMPROVEMENT_PLAN.md new file mode 100644 index 00000000..c988f604 --- /dev/null +++ b/IMPROVEMENT_PLAN.md @@ -0,0 +1,240 @@ +# TextAttack Codebase Improvement Plan + +A prioritized, holistic plan for modernizing and hardening the TextAttack codebase. Each item includes rationale, affected files, and suggested approach. + +**Guiding principle:** Infrastructure, tooling, and non-functional improvements come first so that functional changes benefit from better CI, packaging, and code quality foundations. + +--- + +## Priority 1 — Critical (Infrastructure & CI) + +### 1.1 Re-enable tests in CI + +**Why:** The pytest step in CI is completely commented out (`echo "skipping tests!"` in `run-pytest.yml` line 55). This means every merged PR bypasses the test suite. Without CI tests, regressions accumulate silently, and contributors have no automated safety net. This must be fixed first — all subsequent changes need CI to validate them. + +**Affected file:** `.github/workflows/run-pytest.yml` (lines 54–56) + +**Suggested approach:** Uncomment the `pytest tests -v` line. If tests are failing and that's why they were disabled, fix the failing tests first — disabling CI is not a sustainable workaround. + +### 1.2 Update CI infrastructure + +**Why:** All GitHub Actions workflows use `actions/checkout@v2` and `actions/setup-python@v2`, which are deprecated and will eventually stop working. The CodeQL workflow uses `v1` actions. The Python version matrix only covers 3.8 and 3.9 — Python 3.8 reached end-of-life in October 2024, and 3.10–3.12 are untested. + +**Affected files:** All `.github/workflows/*.yml` files (5 files) + +**Suggested approach:** +- Update to `actions/checkout@v4`, `actions/setup-python@v5`, `github/codeql-action/*@v3`. +- Expand Python matrix to `[3.9, 3.10, 3.11, 3.12]`. +- Drop 3.8 from the matrix and update `python_requires` in setup metadata. +- Replace `python setup.py sdist bdist_wheel` with `python -m build` in publish workflow. + +### 1.3 Update pinned dev tool versions + +**Why:** Test extras pin `black==20.8b1` (from August 2020) and `isort==5.6.4` (from 2020). These versions are incompatible with Python 3.10+ and miss years of bug fixes and formatting improvements. Contributors on modern Python cannot install the dev extras. + +**Affected file:** `setup.py` (lines 20–27) + +**Suggested approach:** Update to current stable versions (`black>=23.0`, `isort>=5.12`). Consider using `pre-commit` to manage formatting tool versions consistently across contributors. + +--- + +## Priority 2 — High (Packaging & Dependencies) + +### 2.1 Modernize packaging: add `pyproject.toml` + +**Why:** The project relies solely on `setup.py`, which is the legacy packaging approach. PEP 517/518 (`pyproject.toml`) is now the standard. The current setup also has fragile patterns: version is imported from `docs/conf.py` at build time (cross-directory import that can break in isolated builds), and `requirements.txt` is read via `open().readlines()` without stripping whitespace. + +**Affected files:** `setup.py`, `docs/conf.py`, `textattack/__init__.py` + +**Suggested approach:** +- Create `pyproject.toml` with build-system metadata, dependencies, and project metadata. +- Move the version string to `textattack/__init__.py` as `__version__` (users expect `textattack.__version__` to work — it currently doesn't exist). +- Replace `setup.py` with a minimal shim or remove it entirely. + +### 2.2 Fix dependency version constraints + +**Why:** 15 of 22 runtime dependencies in `requirements.txt` have no version constraint at all (e.g., `flair`, `nltk`, `language_tool_python`). This means a new release of any of these can silently break TextAttack. The remaining dependencies use only `>=` lower bounds with no upper bounds, which provides minimal protection. + +**Affected file:** `requirements.txt` + +**Suggested approach:** Add compatible-release constraints (`~=`) or upper bounds for all dependencies. At minimum, pin major versions (e.g., `flair>=0.12,<1.0`). Run `pip freeze` on a known-good environment to establish baseline versions. + +--- + +## Priority 3 — Medium (Non-Functional Code Quality) + +### 3.1 Externalize the 10,669-line `data.py` file + +**Why:** `textattack/shared/data.py` is a single 10,669-line file containing only hardcoded named entity lists (country names, person names, etc.). This makes git diffs noisy, IDE indexing slow, and the module hard to navigate. It inflates the package size unnecessarily as Python source. + +**Affected file:** `textattack/shared/data.py` + +**Suggested approach:** Move data into JSON or text files under a `textattack/data/` directory. Load them lazily at first use. This also makes it easier for users to customize or extend the lists. + +### 3.2 Replace deprecated `logger.warn()` with `logger.warning()` + +**Why:** `logger.warn()` has been deprecated since Python 3.2 and may be removed in a future version. It already emits deprecation warnings in some environments. + +**Affected files:** +- `textattack/attacker.py` (lines 94, 182, 353) +- `textattack/trainer.py` (line 116) +- `textattack/shared/validators.py` (lines 59, 74, 83) +- `textattack/shared/utils/misc.py` (line 68) + +**Suggested approach:** Global find-and-replace of `.warn(` with `.warning(` in these files. + +### 3.3 Add type hints to core classes + +**Why:** The core classes (`Attack`, `Attacker`, `GoalFunction`, `SearchMethod`) have essentially zero return type hints. This makes IDE autocompletion unreliable, prevents static analysis from catching bugs, and forces new contributors to read implementation to understand interfaces. `AttackedText` is partially typed (~80%) but inconsistent. + +**Affected files:** +- `textattack/attack.py` — 16 methods, 0 return type hints +- `textattack/attacker.py` — 11 methods, 0 return type hints +- `textattack/goal_functions/goal_function.py` — 18 methods, 0 return type hints +- `textattack/search_methods/search_method.py` — abstract class, no return types + +**Suggested approach:** Add return type annotations to all public methods in these four files first. Use `-> None`, `-> List[AttackedText]`, `-> AttackResult`, etc. This can be done incrementally without breaking changes. + +### 3.4 Replace star imports with explicit imports + +**Why:** Several `__init__.py` files use `from .module import *`, which makes it unclear what names are exported, can cause naming collisions, and breaks static analysis tools. + +**Affected files:** +- `textattack/shared/utils/__init__.py` (lines 1–5) +- `textattack/goal_functions/__init__.py` (lines 11–13) +- `textattack/transformations/__init__.py` (lines 11–14) + +**Suggested approach:** Replace star imports with explicit name lists. If maintaining `__all__` in submodules, that's acceptable — but the importing modules should still list names explicitly. + +### 3.5 Clean up `.gitignore` + +**Why:** The `.gitignore` contains a suspicious line `textattack/=22.3.0` (line 52) that looks like leftover state from pip output, not a valid ignore pattern. + +**Suggested approach:** Remove the invalid line. Audit remaining entries for completeness (add `.env` if missing). + +### 3.6 Add `tests/conftest.py` and expand test coverage + +**Why:** There is no shared test infrastructure (`conftest.py`). Core classes `Attack`, `Attacker`, `GoalFunction`, and `SearchMethod` have no dedicated unit tests. There's a TODO in `test_attacked_text.py` for missing `align_words_with_tokens` tests. + +**Suggested approach:** Create `tests/conftest.py` with shared fixtures (mock models, sample texts, etc.). Add unit tests for core classes. Prioritize testing the attack pipeline and search methods. + +--- + +## Priority 4 — High (Functional Fixes — Security & Correctness) + +These items change runtime behavior. They are ordered after infrastructure so that CI, packaging, and tests are in place to validate them. + +### 4.1 Replace all `eval()` calls with a safe registry/factory pattern + +**Why:** The codebase uses `eval()` extensively to instantiate components from user-supplied strings (attack recipes, transformations, goal functions, constraints, search methods). While inputs are partially validated against predefined dictionaries, `eval()` remains an inherent code-injection vector — especially dangerous in a library that accepts CLI arguments. Any future change that loosens the validation or introduces a new code path could expose users to arbitrary code execution. + +**Affected files:** +- `textattack/attack_args.py` (lines 623–752) — transformations, goal functions, constraints, search methods, recipes +- `textattack/model_args.py` (line 285) — model class instantiation +- `textattack/dataset_args.py` (line 243) — dataset instantiation +- `textattack/training_args.py` (line 589) — attack recipe instantiation +- `textattack/commands/augment_command.py` (lines 36, 84, 182) — augmentation recipes + +**Suggested approach:** Introduce a registry dict mapping string names to classes (e.g., `TRANSFORMATION_REGISTRY = {"word-swap-embedding": WordSwapEmbedding, ...}`). Use `getattr()` on known modules as a fallback. This is safer, faster, and easier to debug than `eval()`. + +### 4.2 Fix the `update_attack_args()` bug + +**Why:** This is a silent logic bug — the method appears to work but never actually updates the intended attribute. It always writes to a literal attribute named `k` instead of the dynamic key. + +**Affected file:** `textattack/attacker.py` (line 460) + +```python +# Current (broken): +self.attack_args.k = kwargs[k] + +# Fix: +setattr(self.attack_args, k, kwargs[k]) +``` + +**Why necessary:** Any code calling `attacker.update_attack_args(num_examples=100)` silently fails. This is a correctness bug that could cause wrong experimental results. + +### 4.3 Replace `assert` with proper exceptions for input validation + +**Why:** Python's `assert` statements are removed when running with `-O` (optimize) flag. Using them for input validation means all runtime checks silently vanish in optimized mode. This is particularly dangerous for a library where users may run in optimized mode for performance. + +**Affected files:** +- `textattack/attack.py` (lines 93–108) — validates constructor arguments +- `textattack/attacker.py` (lines 70–80) — validates attack args +- `textattack/attack_args.py` (lines 230–246) — validates configuration + +**Suggested approach:** Replace `assert condition, message` with `if not condition: raise TypeError(message)` or `ValueError(message)` as appropriate. + +### 4.4 Fix error handling anti-patterns + +**Why:** Several error handling patterns reduce debuggability and correctness: +- `except Exception as e: raise e` (attacker.py:170) — destroys the original traceback by re-raising via variable instead of bare `raise` +- `logging.disable()` without arguments (attacker.py:569) — globally disables ALL logging for the entire process, not just TextAttack +- `torch.cuda.empty_cache()` called without `torch.cuda.is_available()` guard — can fail on CPU-only systems + +**Suggested approach:** +- Change `raise e` to `raise` to preserve traceback +- Replace `logging.disable()` with `logger.setLevel(logging.CRITICAL)` for module-scoped control +- Add `if torch.cuda.is_available():` guard before CUDA calls + +### 4.5 Eliminate module-level side effects + +**Why:** Several modules execute side effects at import time: downloading data, calling `torch.cuda.empty_cache()`, setting environment variables, and importing heavy optional dependencies. This slows down `import textattack`, causes failures when optional deps are missing, and makes testing difficult because imports are no longer pure. + +**Affected files:** +- `textattack/shared/utils/install.py` (lines 203–210) — runs `_post_install_if_needed()` on import, which downloads NLTK data and does network I/O +- `textattack/shared/utils/strings.py` (lines 4–5) — top-level `import flair; import jieba` (should be lazy) +- `textattack/models/wrappers/huggingface_model_wrapper.py` (line 15) — `torch.cuda.empty_cache()` at module level +- `textattack/models/wrappers/pytorch_model_wrapper.py` (line 13) — same issue + +**Suggested approach:** Defer all side effects to first use. Use the `LazyLoader` pattern (already present in the codebase) for optional dependencies. Move CUDA cache clearing into method bodies. Gate network downloads behind explicit function calls. + +### 4.6 Fix thread-safety issue in prompt augmentation + +**Why:** `textattack/prompt_augmentation/prompt_augmentation_pipeline.py` (lines 31–41) mutates a shared augmenter's constraint list by appending a constraint, running augmentation, then popping it off. If an exception occurs between the append and pop, the constraint list is left in a corrupted state. This is also not thread-safe. + +**Suggested approach:** Create a copy of the constraints list or pass constraints as a parameter rather than mutating shared state. + +### 4.7 Use safer serialization where possible + +**Why:** Multiple files use `pickle.load()` to deserialize data downloaded from S3 or user-provided checkpoints. Pickle can execute arbitrary code during deserialization. + +**Affected files:** +- `textattack/shared/checkpoint.py` (lines 221, 226) +- `textattack/shared/word_embeddings.py` (lines 296–298) +- `textattack/transformations/word_swaps/word_swap_hownet.py` (line 30) + +**Suggested approach:** For internally-produced data (embeddings, candidate banks), migrate to safer formats (NumPy `.npy`, JSON, or `safetensors`). For checkpoints, add a warning in documentation about only loading trusted checkpoints. This is a longer-term migration. + +--- + +## Priority 5 — Low (New Features & Long-term Debt) + +### 5.1 Expand LLM integration + +**Why:** The `textattack/llms/` module contains only two thin wrappers (`ChatGPTWrapper`, `HuggingFaceLLMWrapper`). The ChatGPT wrapper has no retry logic, timeout handling, rate limiting, or error handling for missing API keys. These wrappers are not integrated into the main attack pipeline or documented. + +**Suggested approach:** Add proper error handling and retry logic to existing wrappers. Integrate LLM wrappers into the model wrapper hierarchy so they can be used with existing attacks. Document usage in the README and examples. + +### 5.2 Resolve accumulated TODOs + +**Why:** There are 13+ TODO/FIXME/HACK comments scattered across the codebase representing unresolved technical debt. Some are non-trivial bugs: +- `trainer.py:227` — TODO about ground truth manipulation bug +- `particle_swarm_optimization.py:67` — TODO about slow memory buildup +- `word_embedding_distance.py:69` — FIXME: index sometimes larger than tokens-1 +- `attacked_text.py:460` — TODO about undefined punctuation behavior + +**Suggested approach:** Triage each TODO into a GitHub issue with severity label. Fix the bug-class TODOs (trainer, PSO memory, embedding index) as part of Priority 4 work. Convert informational TODOs into GitHub issues and remove the comments. + +--- + +## Summary + +| Priority | Items | Theme | +|----------|-------|-------| +| **P1 — Critical** | 1.1–1.3 | CI re-enablement, CI modernization, dev tooling | +| **P2 — High** | 2.1–2.2 | Packaging modernization, dependency safety | +| **P3 — Medium** | 3.1–3.6 | Non-functional code quality, type hints, tests | +| **P4 — High** | 4.1–4.7 | Functional fixes: security, correctness, runtime behavior | +| **P5 — Low** | 5.1–5.2 | New features, tech debt cleanup | + +**Recommended execution order:** Start with P1 (CI & tooling) so all subsequent changes are validated automatically. Then P2 (packaging & deps) to stabilize the build. P3 (non-functional quality) can proceed in parallel. P4 (functional changes) comes after CI and tests are solid, ensuring behavioral changes are well-tested. P5 items are opportunistic or good first-contributor issues. From 95cf68ea58590c667ebce05396f26359655573fc Mon Sep 17 00:00:00 2001 From: yanjunqiAz <99990718+yanjunqiAz@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:36:48 -0400 Subject: [PATCH 2/3] Fix 7 test failures: compatibility bugs and missing skip markers Code fixes: - Replace removed transformers.optimization.AdamW with torch.optim.AdamW in trainer.py (removed in transformers>=4.x) - Use AutoTokenizer/AutoModelForMaskedLM instead of BertTokenizer/BertForMaskedLM in ChineseWordSwapMaskedLM, since xlm-roberta-base requires its own tokenizer - Fix hardcoded CUDA device in ChineseWordSwapMaskedLM to auto-detect device Test fixes: - Update stale expected output for list_augmentation_recipes to include BackTranscriptionAugmenter - Add pytest.skip for tests requiring tensorflow_hub when not installed (interactive_mode, adv_metrics attack tests, train test) - Add pytest.skipif for test_embedding_gensim when gensim not installed - Replace deprecated gensim Word2VecKeyedVectors API with KeyedVectors --- tests/sample_outputs/list_augmentation_recipes.txt | 1 + tests/test_command_line/test_attack.py | 6 ++++++ tests/test_command_line/test_train.py | 5 +++++ tests/test_word_embedding.py | 11 +++++++---- textattack/trainer.py | 2 +- .../chn_transformations/chinese_word_swap_masked.py | 12 +++++++----- 6 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/sample_outputs/list_augmentation_recipes.txt b/tests/sample_outputs/list_augmentation_recipes.txt index e84f4a6f..3078f4bc 100644 --- a/tests/sample_outputs/list_augmentation_recipes.txt +++ b/tests/sample_outputs/list_augmentation_recipes.txt @@ -1,4 +1,5 @@ back_trans (textattack.augmentation.BackTranslationAugmenter) +back_transcription (textattack.augmentation.BackTranscriptionAugmenter) charswap (textattack.augmentation.CharSwapAugmenter) checklist (textattack.augmentation.CheckListAugmenter) clare (textattack.augmentation.CLAREAugmenter) diff --git a/tests/test_command_line/test_attack.py b/tests/test_command_line/test_attack.py index eaaa9310..2cf665d9 100644 --- a/tests/test_command_line/test_attack.py +++ b/tests/test_command_line/test_attack.py @@ -1,9 +1,12 @@ +import importlib import pdb import re from helpers import run_command_and_get_result import pytest +_tensorflow_hub_available = importlib.util.find_spec("tensorflow_hub") is not None + DEBUG = False """Attack command-line tests in the format (name, args, sample_output_file)""" @@ -171,6 +174,9 @@ @pytest.mark.slow def test_command_line_attack(name, command, sample_output_file): """Runs attack tests and compares their outputs to a reference file.""" + _tf_hub_tests = {"interactive_mode", "attack_from_transformers_adv_metrics", "run_attack_hotflip_lstm_mr_4_adv_metrics"} + if name in _tf_hub_tests and not _tensorflow_hub_available: + pytest.skip("tensorflow_hub is not installed") # read in file and create regex desired_output = open(sample_output_file, "r").read().strip() print("desired_output.encoded =>", desired_output.encode()) diff --git a/tests/test_command_line/test_train.py b/tests/test_command_line/test_train.py index 34809e13..cb9ad6bc 100644 --- a/tests/test_command_line/test_train.py +++ b/tests/test_command_line/test_train.py @@ -1,9 +1,14 @@ +import importlib import os import re from helpers import run_command_and_get_result +import pytest +_tensorflow_hub_available = importlib.util.find_spec("tensorflow_hub") is not None + +@pytest.mark.skipif(not _tensorflow_hub_available, reason="tensorflow_hub is not installed") def test_train_tiny(): command = "textattack train --model distilbert-base-uncased --attack textfooler --dataset rotten_tomatoes --model-max-length 64 --num-epochs 1 --num-clean-epochs 0 --num-train-adv-examples 2" diff --git a/tests/test_word_embedding.py b/tests/test_word_embedding.py index 4772c27d..d781ff0a 100644 --- a/tests/test_word_embedding.py +++ b/tests/test_word_embedding.py @@ -1,3 +1,4 @@ +import importlib import os import numpy as np @@ -5,6 +6,8 @@ from textattack.shared import GensimWordEmbedding, WordEmbedding +_gensim_available = importlib.util.find_spec("gensim") is not None + def test_embedding_paragramcf(): word_embedding = WordEmbedding.counterfitted_GLOVE_embedding() @@ -13,6 +16,7 @@ def test_embedding_paragramcf(): assert word_embedding[10**9] is None +@pytest.mark.skipif(not _gensim_available, reason="gensim is not installed") def test_embedding_gensim(): # download a trained word2vec model from textattack.shared.utils import LazyLoader @@ -30,10 +34,9 @@ def test_embedding_gensim(): ) f.close() - gensim = LazyLoader("gensim", globals(), "gensim") - keyed_vectors = ( - gensim.models.keyedvectors.Word2VecKeyedVectors.load_word2vec_format(path) - ) + from gensim.models import KeyedVectors + + keyed_vectors = KeyedVectors.load_word2vec_format(path) word_embedding = GensimWordEmbedding(keyed_vectors) assert pytest.approx(word_embedding[0][0]) == 1 assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2) diff --git a/textattack/trainer.py b/textattack/trainer.py index 7569dd5d..32f32e59 100644 --- a/textattack/trainer.py +++ b/textattack/trainer.py @@ -361,7 +361,7 @@ def get_optimizer_and_scheduler(self, model, num_training_steps): }, ] - optimizer = transformers.optimization.AdamW( + optimizer = torch.optim.AdamW( optimizer_grouped_parameters, lr=self.training_args.learning_rate ) if isinstance(self.training_args.num_warmup_steps, float): diff --git a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py index 4e12b41f..72370f11 100644 --- a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py +++ b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py @@ -13,11 +13,13 @@ class ChineseWordSwapMaskedLM(WordSwap): model.""" def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs): - from transformers import BertForMaskedLM, BertTokenizer + from transformers import AutoModelForMaskedLM, AutoTokenizer - self.tt = BertTokenizer.from_pretrained(model) - self.mm = BertForMaskedLM.from_pretrained(model) - self.mm.to("cuda") + self.tt = AutoTokenizer.from_pretrained(model) + self.mm = AutoModelForMaskedLM.from_pretrained(model) + device = "cuda" if torch.cuda.is_available() else "cpu" + self.mm.to(device) + self._device = device super().__init__(**kwargs) def get_replacement_words(self, current_text, indice_to_modify): @@ -26,7 +28,7 @@ def get_replacement_words(self, current_text, indice_to_modify): ) # 修改前,xlmrberta的模型 tokens = self.tt.tokenize(masked_text.text) input_ids = self.tt.convert_tokens_to_ids(tokens) - input_tensor = torch.tensor([input_ids]).to("cuda") + input_tensor = torch.tensor([input_ids]).to(self._device) with torch.no_grad(): outputs = self.mm(input_tensor) predictions = outputs.logits From e2c467df01e3f0243c9eb6fb166a42771effd67e Mon Sep 17 00:00:00 2001 From: yanjunqiAz <99990718+yanjunqiAz@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:55:49 -0400 Subject: [PATCH 3/3] Fix CI: use [test] extras instead of [dev] to avoid visdom build failure The [dev] extras include visdom (via [optional]), which fails to build in pip's isolated build environments because visdom's setup.py imports pkg_resources from setuptools, which is not available in isolated builds with newer pip versions. The formatting workflow only needs linting tools and the pytest workflow only needs test tools, so [test] extras are sufficient for both. --- .github/workflows/check-formatting.yml | 2 +- .github/workflows/run-pytest.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check-formatting.yml b/.github/workflows/check-formatting.yml index 3151fa51..226261f4 100644 --- a/.github/workflows/check-formatting.yml +++ b/.github/workflows/check-formatting.yml @@ -27,7 +27,7 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 - pip install -e .[dev] + pip install -e .[test] pip install black flake8 isort --upgrade # Testing packages - name: Check code format with black and isort run: | diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml index c172d0e2..a6435144 100644 --- a/.github/workflows/run-pytest.yml +++ b/.github/workflows/run-pytest.yml @@ -27,9 +27,9 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel pip install pytest pytest-xdist # Testing packages - pip uninstall textattack --yes # Remove TA if it's already installed + pip uninstall textattack --yes # Remove TA if it's already installed python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 - pip install -e .[dev] + pip install -e .[test] pip freeze - name: Free disk space run: |