Skip to content

fix(glm_asr): warn when vLLM dtype=fp16 (degraded output)#2993

Open
SuperMarioYL wants to merge 1 commit into
modelscope:mainfrom
SuperMarioYL:fix/glm-asr-vllm-dtype-guard
Open

fix(glm_asr): warn when vLLM dtype=fp16 (degraded output)#2993
SuperMarioYL wants to merge 1 commit into
modelscope:mainfrom
SuperMarioYL:fix/glm-asr-vllm-dtype-guard

Conversation

@SuperMarioYL

Copy link
Copy Markdown
Contributor

What & why

The GLM-ASR vLLM engine (GLMASRVLLMEngine) accepts dtype="fp16", but — exactly like Fun-ASR-Nano — fp16 can produce degraded or garbage transcription due to numerical overflow in the audio embedding path.

Fun-ASR-Nano already guards against this: inference_vllm.py logs a warning when dtype resolves to float16, pointing users at bf16/fp32 (added in #2980). The newer GLM-ASR engine has the same audio-tower-in-torch + LM-in-vLLM architecture and the same dtype_map default, but never got that warning — so a user who passes fp16 gets silently poor output with no hint about what went wrong.

This PR closes that gap.

Change

  • Add funasr/models/glm_asr/vllm_utils.py with warn_if_degraded_dtype() — a one-time warning when fp16 is requested. The requested value is still honoured (some GPUs only support fp16); the user is just told once why output may be poor.
  • Wire it into GLMASRVLLMEngine.__init__. Behaviour for every dtype is otherwise unchanged — self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16) keeps the existing default.

Fun-ASR-Nano emits this warning inline; I put it in a tiny helper so it can be unit-tested without a GPU/vLLM, matching the dependency-free fun_asr_nano/vllm_utils.py pattern already used for the prompt-embeds repetition-penalty guard (#2974). Happy to inline it instead if you'd prefer it to match the sibling line-for-line.

Tests

tests/test_glm_asr_vllm_dtype.py — covers fp16 warns (once), safe dtypes don't warn, and the value is returned unchanged. The helper is stdlib-only; tests run without a GPU or vLLM (Python 3.8+ compatible, no assertNoLogs).

$ python -m pytest tests/test_glm_asr_vllm_dtype.py -v
4 passed

The GLM-ASR vLLM engine accepts dtype='fp16' but, like Fun-ASR-Nano, fp16
can produce degraded or garbage transcription due to numerical overflow in
the audio embedding path. Fun-ASR-Nano already warns about this; GLM-ASR did
not, so users hitting it got silently poor output with no hint to switch to
bf16/fp32.

Add a small dependency-free helper that emits a one-time warning when fp16 is
requested (the value is still honoured for GPUs that only support fp16), wire
it into GLMASRVLLMEngine.__init__, and cover it with unit tests that run
without a GPU or vLLM.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a warning guard to alert users when using the fp16 compute dtype with GLM-ASR, as it is known to degrade transcription quality due to numerical overflow. The changes include a dependency-free helper utility, its integration into the vLLM inference path, and corresponding unit tests. Feedback on the pull request highlights a potential issue where alternative representations of the dtype (such as torch.float16 or 'float16') would bypass the warning check and cause a dtype mismatch. It is recommended to normalize the dtype parameter to standard string keys before performing the check.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.


self.device = device
self.torch_dtype = dtype_map.get(dtype, torch.bfloat16)
self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If a user passes dtype as "float16" or a torch.dtype object like torch.float16, it will bypass the warning check (which expects exactly "fp16"). Additionally, because "float16" is not in dtype_map, self.torch_dtype would fallback to torch.bfloat16 while vLLM would still receive "float16", leading to a dtype mismatch between the audio tower and the language model.\n\nNormalizing the dtype parameter to standard string keys ("bf16", "fp16", "fp32") at the beginning of __init__ resolves both issues robustly.

        dtype_str = str(dtype).lower().strip() if dtype is not None else ""\n        if "bfloat16" in dtype_str or "bf16" in dtype_str:\n            dtype = "bf16"\n        elif "float16" in dtype_str or "fp16" in dtype_str:\n            dtype = "fp16"\n        elif "float32" in dtype_str or "fp32" in dtype_str:\n            dtype = "fp32"\n        self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant