diff --git a/funasr/bin/_server_app.py b/funasr/bin/_server_app.py index feefcceb8..17c9ba676 100644 --- a/funasr/bin/_server_app.py +++ b/funasr/bin/_server_app.py @@ -132,8 +132,10 @@ def _process_vllm(audio_data, sr, language=None, hotwords=None, use_spk=False): if not seg_audios: return {"text": "", "segments": [], "duration": len(audio_data)/sr} - # vLLM generate with repetition_penalty - gen_kwargs = {"max_new_tokens": 500, "repetition_penalty": 1.3} + # repetition_penalty is left at the neutral 1.0: the Fun-ASR-Nano vLLM + # engine runs in prompt-embeds mode, where any other value crashes the + # CUDA kernel (see issue #2948 and fun_asr_nano.vllm_utils). + gen_kwargs = {"max_new_tokens": 500, "repetition_penalty": 1.0} if language: gen_kwargs["language"] = language if hotwords: diff --git a/funasr/models/fun_asr_nano/inference_vllm.py b/funasr/models/fun_asr_nano/inference_vllm.py index e6cb95aeb..f83df875b 100644 --- a/funasr/models/fun_asr_nano/inference_vllm.py +++ b/funasr/models/fun_asr_nano/inference_vllm.py @@ -527,6 +527,8 @@ def generate( except ImportError: from vllm.inputs.data import EmbedsPrompt + from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty + if isinstance(inputs, (str, np.ndarray, torch.Tensor)): inputs = [inputs] @@ -535,7 +537,8 @@ def generate( temperature=temperature, top_p=top_p, top_k=top_k if top_k > 0 else -1, - repetition_penalty=repetition_penalty, + # Prompt-embeds mode has no token IDs to penalize; see #2948. + repetition_penalty=resolve_repetition_penalty(repetition_penalty), skip_special_tokens=True, ) diff --git a/funasr/models/fun_asr_nano/inference_vllm_pipeline.py b/funasr/models/fun_asr_nano/inference_vllm_pipeline.py index bef4cb7a2..49241bc5b 100644 --- a/funasr/models/fun_asr_nano/inference_vllm_pipeline.py +++ b/funasr/models/fun_asr_nano/inference_vllm_pipeline.py @@ -204,6 +204,8 @@ def _process_one(self, audio_path, **kwargs): except ImportError: from vllm.inputs.data import EmbedsPrompt + from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty + prompts = [] for seg_audio in segment_audios: seg_tensor = torch.from_numpy(seg_audio).float() @@ -219,7 +221,10 @@ def _process_one(self, audio_path, **kwargs): params = SamplingParams( max_tokens=kwargs.get("max_new_tokens", 512), temperature=0.0, - repetition_penalty=1.3, + # Prompt-embeds mode has no token IDs to penalize; see #2948. + repetition_penalty=resolve_repetition_penalty( + kwargs.get("repetition_penalty", 1.0) + ), skip_special_tokens=True, ) diff --git a/funasr/models/fun_asr_nano/inference_vllm_streaming.py b/funasr/models/fun_asr_nano/inference_vllm_streaming.py index 4ff1f9ccf..b90689d55 100644 --- a/funasr/models/fun_asr_nano/inference_vllm_streaming.py +++ b/funasr/models/fun_asr_nano/inference_vllm_streaming.py @@ -234,8 +234,15 @@ def streaming_generate(self, audio_input, chunk_ms=None, rollback_chars=None, chunk_samples = int(self.sample_rate * chunk_ms / 1000) num_chunks = (total_samples + chunk_samples - 1) // chunk_samples - params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature, - repetition_penalty=1.3, skip_special_tokens=True) + from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty + + # Prompt-embeds mode has no token IDs to penalize; see #2948. + params = SamplingParams( + max_tokens=max_new_tokens, temperature=temperature, + repetition_penalty=resolve_repetition_penalty( + kwargs.get("repetition_penalty", 1.0) + ), + skip_special_tokens=True) # Two-stage approach for long audio: # Stage 1: batch first N chunks fresh (no prev_text) to find stable output diff --git a/funasr/models/fun_asr_nano/vllm_utils.py b/funasr/models/fun_asr_nano/vllm_utils.py new file mode 100644 index 000000000..8aedaa318 --- /dev/null +++ b/funasr/models/fun_asr_nano/vllm_utils.py @@ -0,0 +1,58 @@ +"""Helpers shared by the Fun-ASR-Nano vLLM serving paths. + +Kept dependency-free (standard library only) so it can be imported and unit +tested without a CUDA device or a vLLM installation. +""" + +import logging + +logger = logging.getLogger("funasr.fun_asr_nano.vllm") + +# A repetition penalty of 1.0 is the identity value, i.e. "no penalty". +NEUTRAL_REPETITION_PENALTY = 1.0 + +# Warn only once per process so streaming/batch loops do not spam the log. +_warned_prompt_embeds = False + + +def resolve_repetition_penalty(repetition_penalty, *, prompt_embeds=True): + """Return a repetition penalty that is safe for the requested vLLM mode. + + Fun-ASR-Nano feeds vLLM precomputed audio/text *embeddings* with + ``enable_prompt_embeds=True``. In that mode a request carries no prompt + token IDs. vLLM applies ``repetition_penalty`` by scattering over the + prompt's token IDs, so any value other than 1.0 indexes an empty token-id + tensor and aborts the engine with a CUDA + ``scatter gather kernel index out of bounds`` assertion (issue #2948). + + When ``prompt_embeds`` is True we therefore force the penalty back to the + neutral value and warn once. With ``prompt_embeds=False`` (regular + token-prompt decoding) the requested value is passed through unchanged. + + Args: + repetition_penalty: Penalty requested by the caller. ``None`` is + treated as "unset" and maps to the neutral value. + prompt_embeds: Whether the request runs in vLLM prompt-embeds mode. + + Returns: + A repetition penalty that will not crash the engine. + """ + global _warned_prompt_embeds + + if repetition_penalty is None: + return NEUTRAL_REPETITION_PENALTY + + if prompt_embeds and repetition_penalty != NEUTRAL_REPETITION_PENALTY: + if not _warned_prompt_embeds: + logger.warning( + "repetition_penalty=%s is not supported in vLLM prompt-embeds " + "mode (no prompt token IDs to penalize) and would trigger a CUDA " + "scatter index-out-of-bounds crash; using repetition_penalty=%s " + "instead. See https://github.com/modelscope/FunASR/issues/2948.", + repetition_penalty, + NEUTRAL_REPETITION_PENALTY, + ) + _warned_prompt_embeds = True + return NEUTRAL_REPETITION_PENALTY + + return repetition_penalty diff --git a/tests/test_fun_asr_nano_repetition_penalty.py b/tests/test_fun_asr_nano_repetition_penalty.py new file mode 100644 index 000000000..7b6cf8782 --- /dev/null +++ b/tests/test_fun_asr_nano_repetition_penalty.py @@ -0,0 +1,74 @@ +"""Unit tests for Fun-ASR-Nano vLLM repetition-penalty handling. + +Regression guard for issue #2948: a repetition penalty other than 1.0 is +incompatible with vLLM prompt-embeds mode and aborts the engine with a CUDA +"scatter gather index out of bounds" assertion. The serving paths must never +forward such a value to ``SamplingParams`` while ``enable_prompt_embeds=True``. + +The helper is dependency-free, so these tests run without a GPU or vLLM. +""" + +import logging +import unittest + +from funasr.models.fun_asr_nano import vllm_utils +from funasr.models.fun_asr_nano.vllm_utils import ( + NEUTRAL_REPETITION_PENALTY, + resolve_repetition_penalty, +) + + +class TestResolveRepetitionPenalty(unittest.TestCase): + def setUp(self): + # Reset the once-per-process warning flag so each test is independent. + vllm_utils._warned_prompt_embeds = False + + def test_neutral_value_passes_through(self): + self.assertEqual(resolve_repetition_penalty(1.0), 1.0) + + def test_none_maps_to_neutral(self): + self.assertEqual( + resolve_repetition_penalty(None), NEUTRAL_REPETITION_PENALTY + ) + + def test_nonneutral_is_clamped_in_prompt_embeds_mode(self): + # The exact value that triggers the #2948 crash. + self.assertEqual( + resolve_repetition_penalty(1.3, prompt_embeds=True), + NEUTRAL_REPETITION_PENALTY, + ) + + def test_nonneutral_preserved_for_token_prompts(self): + # Regular token-prompt decoding can safely apply the penalty. + self.assertEqual( + resolve_repetition_penalty(1.3, prompt_embeds=False), 1.3 + ) + + def test_warns_once_in_prompt_embeds_mode(self): + with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm: + resolve_repetition_penalty(1.3, prompt_embeds=True) + # Subsequent clamps must not emit additional warnings. + resolve_repetition_penalty(1.5, prompt_embeds=True) + self.assertEqual(len(cm.records), 1) + self.assertIn("2948", cm.output[0]) + + def test_no_warning_when_value_is_safe(self): + # Capture records directly (assertNoLogs is only available on 3.10+). + records = [] + + class _Collect(logging.Handler): + def emit(self, record): + records.append(record) + + handler = _Collect(level=logging.WARNING) + vllm_utils.logger.addHandler(handler) + try: + resolve_repetition_penalty(1.0, prompt_embeds=True) + resolve_repetition_penalty(1.3, prompt_embeds=False) + finally: + vllm_utils.logger.removeHandler(handler) + self.assertEqual(records, []) + + +if __name__ == "__main__": + unittest.main()