-
Notifications
You must be signed in to change notification settings - Fork 1.8k
fix: avoid CUDA crash from repetition_penalty in Fun-ASR-Nano vLLM prompt-embeds mode (#2948) #2974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -204,6 +204,8 @@ def _process_one(self, audio_path, **kwargs): | |
| except ImportError: | ||
| from vllm.inputs.data import EmbedsPrompt | ||
|
|
||
| from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to PEP 8, imports should always be placed at the top of the file. Since References
|
||
|
|
||
| prompts = [] | ||
| for seg_audio in segment_audios: | ||
| seg_tensor = torch.from_numpy(seg_audio).float() | ||
|
|
@@ -219,7 +221,10 @@ def _process_one(self, audio_path, **kwargs): | |
| params = SamplingParams( | ||
| max_tokens=kwargs.get("max_new_tokens", 512), | ||
| temperature=0.0, | ||
| repetition_penalty=1.3, | ||
| # Prompt-embeds mode has no token IDs to penalize; see #2948. | ||
| repetition_penalty=resolve_repetition_penalty( | ||
| kwargs.get("repetition_penalty", 1.0) | ||
| ), | ||
| skip_special_tokens=True, | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -234,8 +234,15 @@ def streaming_generate(self, audio_input, chunk_ms=None, rollback_chars=None, | |
| chunk_samples = int(self.sample_rate * chunk_ms / 1000) | ||
| num_chunks = (total_samples + chunk_samples - 1) // chunk_samples | ||
|
|
||
| params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature, | ||
| repetition_penalty=1.3, skip_special_tokens=True) | ||
| from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to PEP 8, imports should always be placed at the top of the file. Since References
|
||
|
|
||
| # Prompt-embeds mode has no token IDs to penalize; see #2948. | ||
| params = SamplingParams( | ||
| max_tokens=max_new_tokens, temperature=temperature, | ||
| repetition_penalty=resolve_repetition_penalty( | ||
| kwargs.get("repetition_penalty", 1.0) | ||
| ), | ||
| skip_special_tokens=True) | ||
|
|
||
| # Two-stage approach for long audio: | ||
| # Stage 1: batch first N chunks fresh (no prev_text) to find stable output | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,58 @@ | ||||||
| """Helpers shared by the Fun-ASR-Nano vLLM serving paths. | ||||||
|
|
||||||
| Kept dependency-free (standard library only) so it can be imported and unit | ||||||
| tested without a CUDA device or a vLLM installation. | ||||||
| """ | ||||||
|
|
||||||
| import logging | ||||||
|
|
||||||
| logger = logging.getLogger("funasr.fun_asr_nano.vllm") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using a hardcoded string for the logger name (
Suggested change
|
||||||
|
|
||||||
| # A repetition penalty of 1.0 is the identity value, i.e. "no penalty". | ||||||
| NEUTRAL_REPETITION_PENALTY = 1.0 | ||||||
|
|
||||||
| # Warn only once per process so streaming/batch loops do not spam the log. | ||||||
| _warned_prompt_embeds = False | ||||||
|
|
||||||
|
|
||||||
| def resolve_repetition_penalty(repetition_penalty, *, prompt_embeds=True): | ||||||
| """Return a repetition penalty that is safe for the requested vLLM mode. | ||||||
|
|
||||||
| Fun-ASR-Nano feeds vLLM precomputed audio/text *embeddings* with | ||||||
| ``enable_prompt_embeds=True``. In that mode a request carries no prompt | ||||||
| token IDs. vLLM applies ``repetition_penalty`` by scattering over the | ||||||
| prompt's token IDs, so any value other than 1.0 indexes an empty token-id | ||||||
| tensor and aborts the engine with a CUDA | ||||||
| ``scatter gather kernel index out of bounds`` assertion (issue #2948). | ||||||
|
|
||||||
| When ``prompt_embeds`` is True we therefore force the penalty back to the | ||||||
| neutral value and warn once. With ``prompt_embeds=False`` (regular | ||||||
| token-prompt decoding) the requested value is passed through unchanged. | ||||||
|
|
||||||
| Args: | ||||||
| repetition_penalty: Penalty requested by the caller. ``None`` is | ||||||
| treated as "unset" and maps to the neutral value. | ||||||
| prompt_embeds: Whether the request runs in vLLM prompt-embeds mode. | ||||||
|
|
||||||
| Returns: | ||||||
| A repetition penalty that will not crash the engine. | ||||||
| """ | ||||||
| global _warned_prompt_embeds | ||||||
|
|
||||||
| if repetition_penalty is None: | ||||||
| return NEUTRAL_REPETITION_PENALTY | ||||||
|
|
||||||
| if prompt_embeds and repetition_penalty != NEUTRAL_REPETITION_PENALTY: | ||||||
| if not _warned_prompt_embeds: | ||||||
| logger.warning( | ||||||
| "repetition_penalty=%s is not supported in vLLM prompt-embeds " | ||||||
| "mode (no prompt token IDs to penalize) and would trigger a CUDA " | ||||||
| "scatter index-out-of-bounds crash; using repetition_penalty=%s " | ||||||
| "instead. See https://github.com/modelscope/FunASR/issues/2948.", | ||||||
| repetition_penalty, | ||||||
| NEUTRAL_REPETITION_PENALTY, | ||||||
| ) | ||||||
| _warned_prompt_embeds = True | ||||||
| return NEUTRAL_REPETITION_PENALTY | ||||||
|
|
||||||
| return repetition_penalty | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,74 @@ | ||||||||||||||||||||||||||||||||||||
| """Unit tests for Fun-ASR-Nano vLLM repetition-penalty handling. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Regression guard for issue #2948: a repetition penalty other than 1.0 is | ||||||||||||||||||||||||||||||||||||
| incompatible with vLLM prompt-embeds mode and aborts the engine with a CUDA | ||||||||||||||||||||||||||||||||||||
| "scatter gather index out of bounds" assertion. The serving paths must never | ||||||||||||||||||||||||||||||||||||
| forward such a value to ``SamplingParams`` while ``enable_prompt_embeds=True``. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| The helper is dependency-free, so these tests run without a GPU or vLLM. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||
| import unittest | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from funasr.models.fun_asr_nano import vllm_utils | ||||||||||||||||||||||||||||||||||||
| from funasr.models.fun_asr_nano.vllm_utils import ( | ||||||||||||||||||||||||||||||||||||
| NEUTRAL_REPETITION_PENALTY, | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class TestResolveRepetitionPenalty(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||
| def setUp(self): | ||||||||||||||||||||||||||||||||||||
| # Reset the once-per-process warning flag so each test is independent. | ||||||||||||||||||||||||||||||||||||
| vllm_utils._warned_prompt_embeds = False | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_neutral_value_passes_through(self): | ||||||||||||||||||||||||||||||||||||
| self.assertEqual(resolve_repetition_penalty(1.0), 1.0) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_none_maps_to_neutral(self): | ||||||||||||||||||||||||||||||||||||
| self.assertEqual( | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(None), NEUTRAL_REPETITION_PENALTY | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_nonneutral_is_clamped_in_prompt_embeds_mode(self): | ||||||||||||||||||||||||||||||||||||
| # The exact value that triggers the #2948 crash. | ||||||||||||||||||||||||||||||||||||
| self.assertEqual( | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.3, prompt_embeds=True), | ||||||||||||||||||||||||||||||||||||
| NEUTRAL_REPETITION_PENALTY, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_nonneutral_preserved_for_token_prompts(self): | ||||||||||||||||||||||||||||||||||||
| # Regular token-prompt decoding can safely apply the penalty. | ||||||||||||||||||||||||||||||||||||
| self.assertEqual( | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.3, prompt_embeds=False), 1.3 | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_warns_once_in_prompt_embeds_mode(self): | ||||||||||||||||||||||||||||||||||||
| with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm: | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.3, prompt_embeds=True) | ||||||||||||||||||||||||||||||||||||
| # Subsequent clamps must not emit additional warnings. | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.5, prompt_embeds=True) | ||||||||||||||||||||||||||||||||||||
| self.assertEqual(len(cm.records), 1) | ||||||||||||||||||||||||||||||||||||
| self.assertIn("2948", cm.output[0]) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def test_no_warning_when_value_is_safe(self): | ||||||||||||||||||||||||||||||||||||
| # Capture records directly (assertNoLogs is only available on 3.10+). | ||||||||||||||||||||||||||||||||||||
| records = [] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class _Collect(logging.Handler): | ||||||||||||||||||||||||||||||||||||
| def emit(self, record): | ||||||||||||||||||||||||||||||||||||
| records.append(record) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| handler = _Collect(level=logging.WARNING) | ||||||||||||||||||||||||||||||||||||
| vllm_utils.logger.addHandler(handler) | ||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.0, prompt_embeds=True) | ||||||||||||||||||||||||||||||||||||
| resolve_repetition_penalty(1.3, prompt_embeds=False) | ||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||
| vllm_utils.logger.removeHandler(handler) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+63
to
+69
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To ensure the test is robust and independent of any external logging configuration (which might set the logger level to
Suggested change
|
||||||||||||||||||||||||||||||||||||
| self.assertEqual(records, []) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||
| unittest.main() | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to PEP 8, imports should always be placed at the top of the file. Since
vllm_utilsis a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call togenerate.References