Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion funasr/models/glm_asr/inference_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ def __init__(self, model_dir, device="cuda:0", dtype="bf16",
from vllm import LLM
from transformers import AutoProcessor, AutoConfig, AutoModel as HFAutoModel

from funasr.models.glm_asr.vllm_utils import warn_if_degraded_dtype

self.device = device
self.torch_dtype = dtype_map.get(dtype, torch.bfloat16)
self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If a user passes dtype as "float16" or a torch.dtype object like torch.float16, it will bypass the warning check (which expects exactly "fp16"). Additionally, because "float16" is not in dtype_map, self.torch_dtype would fallback to torch.bfloat16 while vLLM would still receive "float16", leading to a dtype mismatch between the audio tower and the language model.\n\nNormalizing the dtype parameter to standard string keys ("bf16", "fp16", "fp32") at the beginning of __init__ resolves both issues robustly.

        dtype_str = str(dtype).lower().strip() if dtype is not None else ""\n        if "bfloat16" in dtype_str or "bf16" in dtype_str:\n            dtype = "bf16"\n        elif "float16" in dtype_str or "fp16" in dtype_str:\n            dtype = "fp16"\n        elif "float32" in dtype_str or "fp32" in dtype_str:\n            dtype = "fp32"\n        self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16)

self.model_dir = model_dir

logger.info(f"Loading GLM-ASR audio components from {model_dir}")
Expand Down
43 changes: 43 additions & 0 deletions funasr/models/glm_asr/vllm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Helpers for the GLM-ASR vLLM serving path.

Kept dependency-free (standard library only) so the dtype guard can be unit
tested without a CUDA device, a torch build, or a vLLM installation.
"""

import logging

logger = logging.getLogger("funasr.glm_asr.vllm")

# Compute dtype that is known to degrade GLM-ASR transcription quality.
DEGRADED_DTYPE = "fp16"

# Warn only once per process so batch loops do not spam the log.
_warned_fp16 = False


def warn_if_degraded_dtype(dtype):
"""Warn once when a compute dtype is known to degrade GLM-ASR output.

``fp16`` can produce degraded or garbage transcription for GLM-ASR
(numerical overflow in the audio embedding path), matching the documented
Fun-ASR-Nano behaviour. The value is still honoured -- some GPUs only
support fp16 -- but the caller is warned once about why output may be poor.

Args:
dtype: Requested compute dtype string ("bf16", "fp16", "fp32").

Returns:
``dtype`` unchanged, so callers can wrap the value inline.
"""
global _warned_fp16

if dtype == DEGRADED_DTYPE and not _warned_fp16:
logger.warning(
"dtype='fp16' can produce degraded or garbage transcription for "
"GLM-ASR (numerical overflow in the audio embedding path). "
"Use dtype='bf16' (recommended) or dtype='fp32'. On GPUs without "
"bfloat16 support (e.g. NVIDIA V100), use 'fp32'."
)
_warned_fp16 = True

return dtype
60 changes: 60 additions & 0 deletions tests/test_glm_asr_vllm_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Unit tests for the GLM-ASR vLLM fp16 degraded-output guard.

Regression guard: ``fp16`` can produce degraded or garbage transcription for
GLM-ASR (numerical overflow in the audio embedding path), mirroring the
documented Fun-ASR-Nano behaviour. Requesting it must warn once so users are
not silently handed poor output, while leaving the requested value untouched.

The helper is dependency-free, so these tests run without a GPU, torch, or
vLLM.
"""

import logging
import unittest

from funasr.models.glm_asr import vllm_utils
from funasr.models.glm_asr.vllm_utils import warn_if_degraded_dtype


class TestWarnIfDegradedDtype(unittest.TestCase):
def setUp(self):
# Reset the once-per-process warning flag so each test is independent.
vllm_utils._warned_fp16 = False

def test_returns_value_unchanged(self):
for dtype in ("bf16", "fp16", "fp32", "something-else"):
self.assertEqual(warn_if_degraded_dtype(dtype), dtype)

def test_fp16_warns(self):
with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm:
warn_if_degraded_dtype("fp16")
self.assertEqual(len(cm.records), 1)
self.assertIn("fp16", cm.output[0])

def test_fp16_warns_only_once(self):
with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm:
warn_if_degraded_dtype("fp16")
# Subsequent calls must not emit additional warnings.
warn_if_degraded_dtype("fp16")
self.assertEqual(len(cm.records), 1)

def test_safe_values_do_not_warn(self):
# Capture records directly (assertNoLogs is only available on 3.10+).
records = []

class _Collect(logging.Handler):
def emit(self, record):
records.append(record)

handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
try:
warn_if_degraded_dtype("bf16")
warn_if_degraded_dtype("fp32")
finally:
vllm_utils.logger.removeHandler(handler)
self.assertEqual(records, [])


if __name__ == "__main__":
unittest.main()