Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions funasr/bin/_server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def _process_vllm(audio_data, sr, language=None, hotwords=None, use_spk=False):
if not seg_audios:
return {"text": "", "segments": [], "duration": len(audio_data)/sr}

# vLLM generate with repetition_penalty
gen_kwargs = {"max_new_tokens": 500, "repetition_penalty": 1.3}
# repetition_penalty is left at the neutral 1.0: the Fun-ASR-Nano vLLM
# engine runs in prompt-embeds mode, where any other value crashes the
# CUDA kernel (see issue #2948 and fun_asr_nano.vllm_utils).
gen_kwargs = {"max_new_tokens": 500, "repetition_penalty": 1.0}
if language:
gen_kwargs["language"] = language
if hotwords:
Expand Down
5 changes: 4 additions & 1 deletion funasr/models/fun_asr_nano/inference_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def generate(
except ImportError:
from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to generate.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)


if isinstance(inputs, (str, np.ndarray, torch.Tensor)):
inputs = [inputs]

Expand All @@ -535,7 +537,8 @@ def generate(
temperature=temperature,
top_p=top_p,
top_k=top_k if top_k > 0 else -1,
repetition_penalty=repetition_penalty,
# Prompt-embeds mode has no token IDs to penalize; see #2948.
repetition_penalty=resolve_repetition_penalty(repetition_penalty),
skip_special_tokens=True,
)

Expand Down
7 changes: 6 additions & 1 deletion funasr/models/fun_asr_nano/inference_vllm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def _process_one(self, audio_path, **kwargs):
except ImportError:
from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to _process_one.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)


prompts = []
for seg_audio in segment_audios:
seg_tensor = torch.from_numpy(seg_audio).float()
Expand All @@ -219,7 +221,10 @@ def _process_one(self, audio_path, **kwargs):
params = SamplingParams(
max_tokens=kwargs.get("max_new_tokens", 512),
temperature=0.0,
repetition_penalty=1.3,
# Prompt-embeds mode has no token IDs to penalize; see #2948.
repetition_penalty=resolve_repetition_penalty(
kwargs.get("repetition_penalty", 1.0)
),
skip_special_tokens=True,
)

Expand Down
11 changes: 9 additions & 2 deletions funasr/models/fun_asr_nano/inference_vllm_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,15 @@ def streaming_generate(self, audio_input, chunk_ms=None, rollback_chars=None,
chunk_samples = int(self.sample_rate * chunk_ms / 1000)
num_chunks = (total_samples + chunk_samples - 1) // chunk_samples

params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature,
repetition_penalty=1.3, skip_special_tokens=True)
from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to streaming_generate.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)


# Prompt-embeds mode has no token IDs to penalize; see #2948.
params = SamplingParams(
max_tokens=max_new_tokens, temperature=temperature,
repetition_penalty=resolve_repetition_penalty(
kwargs.get("repetition_penalty", 1.0)
),
skip_special_tokens=True)

# Two-stage approach for long audio:
# Stage 1: batch first N chunks fresh (no prev_text) to find stable output
Expand Down
58 changes: 58 additions & 0 deletions funasr/models/fun_asr_nano/vllm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Helpers shared by the Fun-ASR-Nano vLLM serving paths.

Kept dependency-free (standard library only) so it can be imported and unit
tested without a CUDA device or a vLLM installation.
"""

import logging

logger = logging.getLogger("funasr.fun_asr_nano.vllm")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a hardcoded string for the logger name ("funasr.fun_asr_nano.vllm") bypasses Python's standard module-based logging hierarchy. It is highly recommended to use __name__ instead, which automatically aligns the logger with the module's package path (funasr.models.fun_asr_nano.vllm_utils) and makes logging configuration much more manageable and standard.

Suggested change
logger = logging.getLogger("funasr.fun_asr_nano.vllm")
logger = logging.getLogger(__name__)


# A repetition penalty of 1.0 is the identity value, i.e. "no penalty".
NEUTRAL_REPETITION_PENALTY = 1.0

# Warn only once per process so streaming/batch loops do not spam the log.
_warned_prompt_embeds = False


def resolve_repetition_penalty(repetition_penalty, *, prompt_embeds=True):
"""Return a repetition penalty that is safe for the requested vLLM mode.

Fun-ASR-Nano feeds vLLM precomputed audio/text *embeddings* with
``enable_prompt_embeds=True``. In that mode a request carries no prompt
token IDs. vLLM applies ``repetition_penalty`` by scattering over the
prompt's token IDs, so any value other than 1.0 indexes an empty token-id
tensor and aborts the engine with a CUDA
``scatter gather kernel index out of bounds`` assertion (issue #2948).

When ``prompt_embeds`` is True we therefore force the penalty back to the
neutral value and warn once. With ``prompt_embeds=False`` (regular
token-prompt decoding) the requested value is passed through unchanged.

Args:
repetition_penalty: Penalty requested by the caller. ``None`` is
treated as "unset" and maps to the neutral value.
prompt_embeds: Whether the request runs in vLLM prompt-embeds mode.

Returns:
A repetition penalty that will not crash the engine.
"""
global _warned_prompt_embeds

if repetition_penalty is None:
return NEUTRAL_REPETITION_PENALTY

if prompt_embeds and repetition_penalty != NEUTRAL_REPETITION_PENALTY:
if not _warned_prompt_embeds:
logger.warning(
"repetition_penalty=%s is not supported in vLLM prompt-embeds "
"mode (no prompt token IDs to penalize) and would trigger a CUDA "
"scatter index-out-of-bounds crash; using repetition_penalty=%s "
"instead. See https://github.com/modelscope/FunASR/issues/2948.",
repetition_penalty,
NEUTRAL_REPETITION_PENALTY,
)
_warned_prompt_embeds = True
return NEUTRAL_REPETITION_PENALTY

return repetition_penalty
74 changes: 74 additions & 0 deletions tests/test_fun_asr_nano_repetition_penalty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Unit tests for Fun-ASR-Nano vLLM repetition-penalty handling.

Regression guard for issue #2948: a repetition penalty other than 1.0 is
incompatible with vLLM prompt-embeds mode and aborts the engine with a CUDA
"scatter gather index out of bounds" assertion. The serving paths must never
forward such a value to ``SamplingParams`` while ``enable_prompt_embeds=True``.

The helper is dependency-free, so these tests run without a GPU or vLLM.
"""

import logging
import unittest

from funasr.models.fun_asr_nano import vllm_utils
from funasr.models.fun_asr_nano.vllm_utils import (
NEUTRAL_REPETITION_PENALTY,
resolve_repetition_penalty,
)


class TestResolveRepetitionPenalty(unittest.TestCase):
def setUp(self):
# Reset the once-per-process warning flag so each test is independent.
vllm_utils._warned_prompt_embeds = False

def test_neutral_value_passes_through(self):
self.assertEqual(resolve_repetition_penalty(1.0), 1.0)

def test_none_maps_to_neutral(self):
self.assertEqual(
resolve_repetition_penalty(None), NEUTRAL_REPETITION_PENALTY
)

def test_nonneutral_is_clamped_in_prompt_embeds_mode(self):
# The exact value that triggers the #2948 crash.
self.assertEqual(
resolve_repetition_penalty(1.3, prompt_embeds=True),
NEUTRAL_REPETITION_PENALTY,
)

def test_nonneutral_preserved_for_token_prompts(self):
# Regular token-prompt decoding can safely apply the penalty.
self.assertEqual(
resolve_repetition_penalty(1.3, prompt_embeds=False), 1.3
)

def test_warns_once_in_prompt_embeds_mode(self):
with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm:
resolve_repetition_penalty(1.3, prompt_embeds=True)
# Subsequent clamps must not emit additional warnings.
resolve_repetition_penalty(1.5, prompt_embeds=True)
self.assertEqual(len(cm.records), 1)
self.assertIn("2948", cm.output[0])

def test_no_warning_when_value_is_safe(self):
# Capture records directly (assertNoLogs is only available on 3.10+).
records = []

class _Collect(logging.Handler):
def emit(self, record):
records.append(record)

handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
try:
resolve_repetition_penalty(1.0, prompt_embeds=True)
resolve_repetition_penalty(1.3, prompt_embeds=False)
finally:
vllm_utils.logger.removeHandler(handler)
Comment on lines +63 to +69

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To ensure the test is robust and independent of any external logging configuration (which might set the logger level to ERROR or higher and cause false positives), it is safer to explicitly set the logger's level to logging.WARNING (or lower) during the test execution and restore it afterward.

Suggested change
handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
try:
resolve_repetition_penalty(1.0, prompt_embeds=True)
resolve_repetition_penalty(1.3, prompt_embeds=False)
finally:
vllm_utils.logger.removeHandler(handler)
handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
original_level = vllm_utils.logger.level
vllm_utils.logger.setLevel(logging.WARNING)
try:
resolve_repetition_penalty(1.0, prompt_embeds=True)
resolve_repetition_penalty(1.3, prompt_embeds=False)
finally:
vllm_utils.logger.setLevel(original_level)
vllm_utils.logger.removeHandler(handler)

self.assertEqual(records, [])


if __name__ == "__main__":
unittest.main()