fix(glm_asr): warn when vLLM dtype=fp16 (degraded output) #2993
Open
SuperMarioYL wants to merge 1 commit into modelscope:main from SuperMarioYL:fix/glm-asr-vllm-dtype-guard
+106
−1
@@ -0,0 +1,43 @@

    """Helpers for the GLM-ASR vLLM serving path.

    Kept dependency-free (standard library only) so the dtype guard can be unit
    tested without a CUDA device, a torch build, or a vLLM installation.
    """

    import logging

    logger = logging.getLogger("funasr.glm_asr.vllm")

    # Compute dtype that is known to degrade GLM-ASR transcription quality.
    DEGRADED_DTYPE = "fp16"

    # Warn only once per process so batch loops do not spam the log.
    _warned_fp16 = False


    def warn_if_degraded_dtype(dtype):
        """Warn once when a compute dtype is known to degrade GLM-ASR output.

        ``fp16`` can produce degraded or garbage transcription for GLM-ASR
        (numerical overflow in the audio embedding path), matching the documented
        Fun-ASR-Nano behaviour. The value is still honoured -- some GPUs only
        support fp16 -- but the caller is warned once about why output may be poor.

        Args:
            dtype: Requested compute dtype string ("bf16", "fp16", "fp32").

        Returns:
            ``dtype`` unchanged, so callers can wrap the value inline.
        """
        global _warned_fp16

        if dtype == DEGRADED_DTYPE and not _warned_fp16:
            logger.warning(
                "dtype='fp16' can produce degraded or garbage transcription for "
                "GLM-ASR (numerical overflow in the audio embedding path). "
                "Use dtype='bf16' (recommended) or dtype='fp32'. On GPUs without "
                "bfloat16 support (e.g. NVIDIA V100), use 'fp32'."
            )
            _warned_fp16 = True

        return dtype
@@ -0,0 +1,60 @@

    """Unit tests for the GLM-ASR vLLM fp16 degraded-output guard.

    Regression guard: ``fp16`` can produce degraded or garbage transcription for
    GLM-ASR (numerical overflow in the audio embedding path), mirroring the
    documented Fun-ASR-Nano behaviour. Requesting it must warn once so users are
    not silently handed poor output, while leaving the requested value untouched.

    The helper is dependency-free, so these tests run without a GPU, torch, or
    vLLM.
    """

    import logging
    import unittest

    from funasr.models.glm_asr import vllm_utils
    from funasr.models.glm_asr.vllm_utils import warn_if_degraded_dtype


    class TestWarnIfDegradedDtype(unittest.TestCase):
        def setUp(self):
            # Reset the once-per-process warning flag so each test is independent.
            vllm_utils._warned_fp16 = False

        def test_returns_value_unchanged(self):
            for dtype in ("bf16", "fp16", "fp32", "something-else"):
                self.assertEqual(warn_if_degraded_dtype(dtype), dtype)

        def test_fp16_warns(self):
            with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm:
                warn_if_degraded_dtype("fp16")
            self.assertEqual(len(cm.records), 1)
            self.assertIn("fp16", cm.output[0])

        def test_fp16_warns_only_once(self):
            with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm:
                warn_if_degraded_dtype("fp16")
                # Subsequent calls must not emit additional warnings.
                warn_if_degraded_dtype("fp16")
            self.assertEqual(len(cm.records), 1)

        def test_safe_values_do_not_warn(self):
            # Capture records directly (assertNoLogs is only available on 3.10+).
            records = []

            class _Collect(logging.Handler):
                def emit(self, record):
                    records.append(record)

            handler = _Collect(level=logging.WARNING)
            vllm_utils.logger.addHandler(handler)
            try:
                warn_if_degraded_dtype("bf16")
                warn_if_degraded_dtype("fp32")
            finally:
                vllm_utils.logger.removeHandler(handler)
            self.assertEqual(records, [])


    if __name__ == "__main__":
        unittest.main()
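The last test attaches a collecting handler by hand because `assertNoLogs` requires Python 3.10+. If more "must not log" tests are added, the same pattern could be factored into a small context manager. The sketch below is only an illustration: `capture_records` is a hypothetical name and is not part of this PR.

```python
import contextlib
import logging


@contextlib.contextmanager
def capture_records(logger, level=logging.WARNING):
    """Yield a list that collects records emitted on `logger` at `level` or above.

    Hypothetical test helper (not part of this PR); it mirrors the manual
    handler setup used in test_safe_values_do_not_warn.
    """
    records = []

    class _Collect(logging.Handler):
        def emit(self, record):
            records.append(record)

    handler = _Collect(level=level)
    logger.addHandler(handler)
    try:
        yield records
    finally:
        # Always detach the handler so later tests see a clean logger.
        logger.removeHandler(handler)
```

A test body would then read `with capture_records(vllm_utils.logger) as records: ...` followed by `self.assertEqual(records, [])`.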
If a user passes `dtype` as `"float16"` or a `torch.dtype` object like `torch.float16`, it will bypass the warning check (which expects exactly `"fp16"`). Additionally, because `"float16"` is not in `dtype_map`, `self.torch_dtype` would fall back to `torch.bfloat16` while vLLM would still receive `"float16"`, leading to a dtype mismatch between the audio tower and the language model.

Normalizing the `dtype` parameter to standard string keys (`"bf16"`, `"fp16"`, `"fp32"`) at the beginning of `__init__` resolves both issues robustly.
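One way the suggested normalization might look is sketched below. This is not code from the PR: `normalize_dtype` and `_DTYPE_ALIASES` are hypothetical names, the alias list is an assumption, and raising on unknown values is a design choice that differs from the fallback behaviour described above. It relies on `str(torch.float16)` being `"torch.float16"`, so `torch.dtype` objects can be handled without importing torch.

```python
# Hypothetical alias table; the canonical keys follow the review comment.
_DTYPE_ALIASES = {
    "bf16": "bf16",
    "bfloat16": "bf16",
    "fp16": "fp16",
    "float16": "fp16",
    "half": "fp16",
    "fp32": "fp32",
    "float32": "fp32",
    "float": "fp32",
}


def normalize_dtype(dtype):
    """Map a dtype string or torch.dtype-like object to "bf16", "fp16", or "fp32".

    Hypothetical helper, not part of this PR.
    """
    # Lower-case the string form and drop a leading "torch." so that
    # torch.float16 (whose str() is "torch.float16") and "float16" agree.
    name = str(dtype).strip().lower()
    if name.startswith("torch."):
        name = name[len("torch."):]
    try:
        return _DTYPE_ALIASES[name]
    except KeyError:
        # Assumption: fail loudly instead of silently falling back to bf16.
        raise ValueError(f"Unsupported dtype: {dtype!r}") from None
```

With this in place, `normalize_dtype("float16")` yields `"fp16"`, so a later exact comparison against `"fp16"` in the warning helper would fire, and the same key could be used for both the `dtype_map` lookup and the value handed to vLLM.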