diff --git a/autotest/config_h.yml b/autotest/config_h.yml index bbae1ced5e..dcac475bbf 100644 --- a/autotest/config_h.yml +++ b/autotest/config_h.yml @@ -45,6 +45,12 @@ config: internlm/Intern-S1-Pro-FP8: dp: 16 ep: 16 + Qwen/Qwen3.5-397B-A17B: + dp: 4 + ep: 8 + Qwen/Qwen3.5-397B-A17B-FP8: + dp: 4 + ep: 8 cp_tp: Qwen/Qwen3-235B-A22B-Thinking-2507: @@ -122,6 +128,8 @@ pytorch_chat_model: - Qwen/Qwen3.5-35B-A3B - Qwen/Qwen3.5-35B-A3B-FP8 - Qwen/Qwen3.5-122B-A10B + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b @@ -159,6 +167,8 @@ pytorch_vl_model: - Qwen/Qwen3.5-35B-A3B - Qwen/Qwen3.5-35B-A3B-FP8 - Qwen/Qwen3.5-122B-A10B + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b @@ -283,6 +293,8 @@ pytorch_quantization: - internlm/Intern-S1 - internlm/Intern-S1-mini - internlm/Intern-S1-Pro-FP8 + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 no_kvint8: - zai-org/GLM-4.7-Flash - zai-org/GLM-5-FP8 @@ -293,6 +305,10 @@ pytorch_quantization: - Qwen/Qwen3.5-122B-A10B - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - internlm/Intern-S1-Pro-FP8 + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 + fp8: + - Qwen/Qwen3.5-397B-A17B longtext_benchmark_model: - Qwen/Qwen3-30B-A3B @@ -335,13 +351,19 @@ evaluate_model: - deepseek-ai/DeepSeek-V3.1 - zai-org/GLM-5-FP8 - internlm/Intern-S1-Pro-FP8 + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 longtext_evaluate_model: - Qwen/Qwen3.5-35B-A3B + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 mtp_evaluate_model: - Qwen/Qwen3.5-35B-A3B - Qwen/Qwen3.5-35B-A3B-FP8 + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 mllm_evaluate_model: - OpenGVLab/InternVL3_5-38B @@ -352,3 +374,5 @@ mllm_evaluate_model: - Qwen/Qwen3.5-122B-A10B - internlm/Intern-S1 - internlm/Intern-S1-mini + - Qwen/Qwen3.5-397B-A17B + - Qwen/Qwen3.5-397B-A17B-FP8 diff --git a/autotest/evaluate/test_api_evaluate.py b/autotest/evaluate/test_api_evaluate.py index c94896678f..20af79e2e7 100644 --- a/autotest/evaluate/test_api_evaluate.py +++ b/autotest/evaluate/test_api_evaluate.py @@ -63,18 +63,21 @@ def _run_proxy_distributed_test(config, worker_id, test_type='infer', manager=None, - eval_config_name='default'): + eval_config_name='default', + eval_subpath=None): assert manager is not None, 'Manager instance must be provided' - if 'gpt' in run_config.get('model', '').lower(): - eval_config_name = 'gpt' - elif 'intern-s1-pro' in run_config.get('model', '').lower(): - eval_config_name = 'intern-s1-pro' - elif 'qwen3.5' in run_config.get('model', '').lower(): - eval_config_name = 'qwen3.5' + if eval_subpath is None: + if eval_config_name == 'default': + if 'gpt' in run_config.get('model', '').lower(): + eval_config_name = 'gpt' + elif 'intern-s1-pro' in run_config.get('model', '').lower(): + eval_config_name = 'intern-s1-pro' + elif 'qwen3.5' in run_config.get('model', '').lower(): + eval_config_name = 'qwen3.5' - if str(config.get('env_tag')) == 'ascend': - eval_config_name = f'{eval_config_name}-2batch' + if str(config.get('env_tag')) == 'ascend': + eval_config_name = f'{eval_config_name}-2batch' preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {}) model_name = run_config['model'] @@ -88,6 +91,9 @@ def _run_proxy_distributed_test(config, api_server.wait_until_ready() print(f'πŸ§ͺ Master node executing {test_type} test ({eval_config_name})...') eval_path = config.get('eval_path') + if eval_subpath: + eval_path = os.path.join(eval_path, eval_subpath) + os.makedirs(eval_path, exist_ok=True) case_name = get_case_str_by_config(run_config) extra_config = {'max-num-workers': 16} @@ -98,6 +104,7 @@ def _run_proxy_distributed_test(config, port=constant.PROXY_PORT, test_type=test_type, extra_config=extra_config, + eval_config_name=eval_config_name, **preset_config) assert result, f'❌ {test_type} test failed: {msg}' print(f'βœ… {test_type} test passed') @@ -282,7 +289,30 @@ def test_pytorch_restful_tp2(config, run_config, worker_id): ), ) def test_pytorch_restful_tp2_longtext(config, run_config, worker_id): - run_eval_test(config, run_config, worker_id, 'infer', eval_config_name='longtext-256k') + run_eval_test(config, run_config, worker_id, 'infer', eval_subpath='longtext', eval_config_name='longtext-256k') + + +@pytest.mark.infer +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize( + 'run_config', + get_func_config_list( + 'pytorch', + {'dp': 4, 'ep': 8}, + func_type='longtext_evaluate', + extra={'session_len': 400000}, + ), +) +def test_pytorch_restful_distributed_dp4ep8_longtext(shared_proxy_manager, config, run_config, worker_id): + _run_proxy_distributed_test(config=config, + run_config=run_config, + worker_id=worker_id, + test_type='infer', + manager=shared_proxy_manager, + eval_config_name='longtext-256k', + eval_subpath='longtext') @pytest.mark.infer @@ -398,6 +428,19 @@ def test_pytorch_restful_distributed_dpep8(shared_proxy_manager, config, run_con manager=shared_proxy_manager) +@pytest.mark.infer +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('run_config', get_func_config_list('pytorch', {'dp': 4, 'ep': 8}, func_type='evaluate')) +def test_pytorch_restful_distributed_dp4ep8(shared_proxy_manager, config, run_config, worker_id): + _run_proxy_distributed_test(config=config, + run_config=run_config, + worker_id=worker_id, + test_type='infer', + manager=shared_proxy_manager) + + @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep16 @@ -515,6 +558,15 @@ def test_pytorch_eval_distributed_dpep8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') +@pytest.mark.eval +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('run_config', get_func_config_list('pytorch', {'dp': 4, 'ep': 8}, func_type='evaluate')) +def test_pytorch_eval_distributed_dp4ep8(config, run_config, worker_id): + run_eval_test(config, run_config, worker_id, 'eval') + + @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep16 @@ -538,7 +590,24 @@ def test_pytorch_eval_distributed_dpep16(config, run_config, worker_id): ), ) def test_pytorch_eval_tp2_longtext(config, run_config, worker_id): - run_eval_test(config, run_config, worker_id, 'eval', eval_config_name='longtext-256k') + run_eval_test(config, run_config, worker_id, 'eval', eval_subpath='longtext', eval_config_name='longtext-256k') + + +@pytest.mark.eval +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize( + 'run_config', + get_func_config_list( + 'pytorch', + {'dp': 4, 'ep': 8}, + func_type='longtext_evaluate', + extra={'session_len': 400000}, + ), +) +def test_pytorch_eval_distributed_dp4ep8_longtext(config, run_config, worker_id): + run_eval_test(config, run_config, worker_id, 'eval', eval_subpath='longtext', eval_config_name='longtext-256k') @pytest.mark.eval diff --git a/autotest/evaluate/test_mllm_api_evaluate.py b/autotest/evaluate/test_mllm_api_evaluate.py index 796a333155..c5b8e7ea34 100644 --- a/autotest/evaluate/test_mllm_api_evaluate.py +++ b/autotest/evaluate/test_mllm_api_evaluate.py @@ -1,15 +1,23 @@ import os +import time import pytest import utils.constant as constant from utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid from utils.evaluate_utils import mllm_eval_test +from utils.proxy_distributed_utils import ApiServerPerTest, proxy_worker_node_wait from utils.run_restful_chat import start_openai_service, start_proxy_server, stop_restful_api, terminate_restful_api -def run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default'): +def run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default', eval_subpath=None): + if eval_config_name == 'default': + if 'qwen3.5' in run_config.get('model', '').lower(): + eval_config_name = 'qwen3.5' extra_config = constant.MLLM_EVAL_CONFIGS.get(eval_config_name, {}) eval_path = config.get('mllm_eval_path') + if eval_subpath: + eval_path = os.path.join(eval_path, eval_subpath) + os.makedirs(eval_path, exist_ok=True) case_name = get_case_str_by_config(run_config) if test_type == 'infer': proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), constant.PROXY_PORT, @@ -66,6 +74,57 @@ def run_openai_service_start(i): stop_restful_api(proxy_pid, proxy_process) +def _run_proxy_distributed_mllm_test( + config, + run_config, + worker_id, + test_type='infer', + manager=None, + eval_config_name='default'): + assert manager is not None, 'Manager instance must be provided' + + if eval_config_name == 'default': + if 'qwen3.5' in run_config.get('model', '').lower(): + eval_config_name = 'qwen3.5' + + if str(config.get('env_tag')) == 'ascend': + eval_config_name = f'{eval_config_name}-2batch' + + preset_config = constant.MLLM_EVAL_CONFIGS.get(eval_config_name, {}) + model_name = run_config['model'] + model_path = os.path.join(config['model_path'], model_name) + + api_server = ApiServerPerTest(proxy_manager=manager, config=config, run_config=run_config) + api_server.start() + + try: + if manager.is_master: + api_server.wait_until_ready() + print(f'πŸ§ͺ Master node executing mllm {test_type} test ({eval_config_name})...') + eval_path = config.get('mllm_eval_path') + case_name = get_case_str_by_config(run_config) + extra_config = {'api-nproc': 16} + extra_config.update(preset_config) + + result, msg = mllm_eval_test(model_path, + eval_path, + case_name, + port=constant.PROXY_PORT, + test_type=test_type, + extra_config=extra_config) + assert result, f'❌ mllm {test_type} test failed: {msg}' + print(f'βœ… mllm {test_type} test passed') + + else: + print(f'⏸️ Worker node {manager.node_rank} waiting for master to complete mllm test...') + proxy_worker_node_wait(manager, timeout_minutes=4880) + + finally: + api_server.cleanup() + if manager.is_master: + time.sleep(1) + + def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, @@ -247,3 +306,25 @@ def test_pytorch_eval_tp8(config, run_config, worker_id): @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_eval_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') + + +@pytest.mark.infer +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 4, 'ep': 8})) +def test_pytorch_vl_restful_distributed_dp4ep8(shared_proxy_manager, config, run_config, worker_id): + _run_proxy_distributed_mllm_test(config=config, + run_config=run_config, + worker_id=worker_id, + test_type='infer', + manager=shared_proxy_manager) + + +@pytest.mark.eval +@pytest.mark.pytorch +@pytest.mark.gpu_num_distributed_dp4ep8 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 4, 'ep': 8})) +def test_pytorch_vl_eval_distributed_dp4ep8(config, run_config, worker_id): + run_eval_test(config, run_config, worker_id, 'eval') diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py index df43e43f0c..e73d0e50ff 100644 --- a/autotest/utils/config_utils.py +++ b/autotest/utils/config_utils.py @@ -63,6 +63,14 @@ def get_func_config_list(backend: str, run_configs = [] dtype = 'float16' if not is_bf16_supported(device) else None + quantization_config = config.get(f'{backend}_quantization', {}) + fp8_model_list = quantization_config.get('fp8', []) + + def get_model_extra_params(model: str) -> dict: + if model in fp8_model_list: + return {'model-format': 'fp8'} + return {} + for communicator in _get_communicator_list(config, backend, parallel_config): for model in base_case_list: for quant_policy in [0, 4, 8]: @@ -95,6 +103,13 @@ def get_func_config_list(backend: str, run_config['extra_params']['dtype'] = dtype if device != 'cuda': run_config['extra_params']['device'] = device + + model_extra_params = get_model_extra_params(model) + if model_extra_params and quant_policy == 0: + run_config_with_format = copy.deepcopy(run_config) + run_config_with_format['extra_params'].update(model_extra_params) + run_configs.append(run_config_with_format) + run_configs.append(run_config) for run_config in run_configs: @@ -109,6 +124,10 @@ def get_func_config_list(backend: str, run_config['extra_params']['cache-max-entry-count'] = 0.9 run_config['extra_params']['max-batch-size'] = 128 + if 'Qwen3.5-397B-A17B' in run_config['model']: + run_config['extra_params']['max-batch-size'] = 256 + run_config['extra_params']['cache-max-entry-count'] = 0.9 + if (func_type == 'evaluate' and 'session_len' not in extra and 'session-len' not in extra and 'Qwen3.5' not in run_config['model']): run_config['extra_params']['session_len'] = 65536 @@ -143,7 +162,7 @@ def get_func_config_list(backend: str, and func_type in ('benchmark', 'longtext_benchmark')): run_config['extra_params']['model-format'] = 'mxfp4' - if func_type == 'mtp_evaluate' and 'Qwen3.5' in run_config['model']: + if func_type == 'mtp_evaluate': run_config['extra_params'].update({ 'reasoning-parser': 'qwen-qwq', 'speculative-algorithm': 'qwen3_5_mtp', @@ -521,6 +540,9 @@ def get_case_str_by_config(run_config: dict[str, Any], is_simple: bool = True) - # Get last section of model name, compatible with model name contains '/' pure_model_name = model_name.split('/')[-1].replace('_', '-') extra_params_case = '' + model_format = extra_params.get('model-format') + if model_format: + extra_params_case += f'_{model_format}' if not is_simple: for k, v in extra_params.items(): if len(v) > 10: @@ -535,12 +557,23 @@ def parse_config_by_case(case_str: str) -> dict[str, Any]: """Parse run config dict from case name string (fix split & type convert bug)""" case_parts = case_str.split('_') - # Parse fixed field & reassemble dynamic parallel config + if len(case_parts) < 4: + raise ValueError(f'Invalid case string: {case_str}') + backend = case_parts[0] model = case_parts[1] communicator = case_parts[2] - quant_policy = int(case_parts[-1]) - parallel_parts = case_parts[3:-1] + + quant_idx = None + for i in range(len(case_parts) - 1, 2, -1): + if case_parts[i].isdigit(): + quant_idx = i + break + if quant_idx is None: + raise ValueError(f'No numeric quant policy found in case string: {case_str}') + + quant_policy = int(case_parts[quant_idx]) + parallel_parts = case_parts[3:quant_idx] # Convert parallel str to dict, e.g: ['tp1','pp2'] -> {'tp':1, 'pp':2} parallel_config = {} diff --git a/autotest/utils/evaluate_utils.py b/autotest/utils/evaluate_utils.py index 07025b9b56..91a832f876 100644 --- a/autotest/utils/evaluate_utils.py +++ b/autotest/utils/evaluate_utils.py @@ -4,6 +4,8 @@ import os import subprocess import time +import traceback +from contextlib import contextmanager import allure import pandas as pd @@ -13,6 +15,50 @@ from utils.constant import DEFAULT_PORT, DEFAULT_SERVER, EVAL_RUN_CONFIG +@contextmanager +def _mmengine_lazy_allow_lazyattr_call(): + try: + from mmengine.config import lazy as mm_lazy + except ImportError: + yield + return + cls = mm_lazy.LazyAttr + orig = cls.__call__ + + def _call(self, *args, **kwargs): + fn = self.build() + if not callable(fn): + raise RuntimeError() + return fn(*args, **kwargs) + + cls.__call__ = _call + try: + yield + finally: + cls.__call__ = orig + + +def _sync_ruler_tokenizer_model(cfg, model_path): + tokenizer = os.environ.get('TOKENIZER_MODEL', model_path) + if not getattr(cfg, 'datasets', None): + return + for dataset in cfg.datasets: + if isinstance(dataset, dict) and 'tokenizer_model' in dataset: + dataset['tokenizer_model'] = tokenizer + + +def _is_fp8_case(case_name: str) -> bool: + return case_name.endswith('_fp8') + + +def _should_skip_num_workers_override(eval_config_name: str, case_name: str) -> bool: + if eval_config_name in ('longtext-256k', 'longtext-512k'): + return True + if _is_fp8_case(case_name): + return True + return False + + def write_to_summary(case_name, result, msg, metrics, result_dir): status = 'βœ… PASS' if result else f'❌ FAIL {msg}' @@ -186,7 +232,8 @@ def eval_test(model_path, if not os.path.exists(config_file): return False, f'Config file {config_file} not found' - cfg = Config.fromfile(config_file) + with _mmengine_lazy_allow_lazyattr_call(): + cfg = Config.fromfile(config_file) cfg.MODEL_NAME = case_name cfg.MODEL_PATH = model_path @@ -202,8 +249,12 @@ def eval_test(model_path, for key, value in kwargs.items(): model_cfg[key] = value - cfg.NUM_WORKERS = extra_config.get('max-num-workers', 8) - cfg.infer['partitioner']['num_worker'] = extra_config.get('max-num-workers', 8) + _sync_ruler_tokenizer_model(cfg, model_path) + + if not _should_skip_num_workers_override(eval_config_name, case_name): + cfg.NUM_WORKERS = extra_config.get('max-num-workers', 8) + cfg.infer['partitioner']['num_worker'] = extra_config.get( + 'max-num-workers', 8) cfg.dump(temp_config_path) print(f'Modified config saved to: {temp_config_path}') @@ -213,7 +264,8 @@ def eval_test(model_path, llm_summary(case_name, False, error_msg, work_dir, eval_path) return False, error_msg - cfg = Config.fromfile(temp_config_path) + with _mmengine_lazy_allow_lazyattr_call(): + cfg = Config.fromfile(temp_config_path) print(f'Using existing temp config file: {temp_config_path}') eval_run_config = EVAL_RUN_CONFIG eval_case_name = get_case_str_by_config(eval_run_config) @@ -244,6 +296,8 @@ def eval_test(model_path, evaluator['llm_evaluator']['judge_cfg']['openai_api_base'] = cfg.JUDGE_API_BASE evaluator['llm_evaluator']['judge_cfg']['tokenizer_path'] = cfg.JUDGE_MODEL_PATH + _sync_ruler_tokenizer_model(cfg, model_path) + cfg.dump(temp_config_path) print(f'Modified config for eval stage saved to: {temp_config_path}') @@ -262,7 +316,8 @@ def eval_test(model_path, return result, stderr except Exception as e: print(f'Error occurred: {e}') - return False, f'Error occurred: {e}' + print(traceback.format_exc()) + return False, f'Error occurred: {e}\n{traceback.format_exc()}' finally: os.chdir(original_cwd) print(f'Returned to directory: {original_cwd}') @@ -295,7 +350,7 @@ def mllm_eval_test(model_path, eval_path, case_name, port=DEFAULT_PORT, test_typ extra_config_str = get_cli_str(extra_config) if test_type == 'infer': - cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:{port}/v1 --reuse --work-dir {work_dir} --mode infer {extra_config_str}' # noqa + cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:{port}/v1 --reuse --work-dir {work_dir} --timeout 7200 --mode infer {extra_config_str}' # noqa elif test_type == 'eval': cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:empty/v1 --reuse --work-dir {work_dir} --api-nproc 32 --mode eval --judge turbomind_Qwen2.5-32B-Instruct_nccl_tp2_0 --judge-base-url http://{DEFAULT_SERVER}:{port}/v1' # noqa diff --git a/autotest/utils/proxy_distributed_utils.py b/autotest/utils/proxy_distributed_utils.py index fa8afe7997..149d81acfc 100644 --- a/autotest/utils/proxy_distributed_utils.py +++ b/autotest/utils/proxy_distributed_utils.py @@ -52,8 +52,8 @@ def check_nodes_status(host: str, proxy_port: int, model_name: str, expected_ins if should_print: basename = os.path.basename(model_name) print(f'πŸ“Š Check {check_count}: Model registration progress: ' - f'{ready_instances}/{expected_instances} instances ready ' - f'(Total reported: {total_instances})') + f'{ready_instances}/{expected_instances} nodes with model, ' + f'{total_instances}/{expected_instances} nodes seen by proxy') for node_url, node_info in nodes_data.items(): models = node_info.get('models', []) if model_name in models: @@ -61,15 +61,22 @@ def check_nodes_status(host: str, proxy_port: int, model_name: str, expected_ins else: print(f' ⏳ Instance {node_url} has not registered target model') - if ready_instances >= expected_instances: + if total_instances != expected_instances: if should_print: - print(f'🎯 All {expected_instances} API server instances have registered the target model') - return True, ready_instances - else: - if should_print: - print(f'⏳ Waiting for more instances to register... ({ready_instances}/{expected_instances})') + print(f'⏳ Waiting for proxy to see exactly {expected_instances} nodes ' + f'(dp-sized cluster); currently {total_instances}') return False, ready_instances + if ready_instances == expected_instances: + if should_print: + print(f'🎯 All {expected_instances} nodes registered the target model ' + f'(matches /nodes/status count)') + return True, ready_instances + + if should_print: + print(f'⏳ Waiting for all nodes to register model... ({ready_instances}/{expected_instances})') + return False, ready_instances + except Exception as e: if current_time - last_progress_print >= progress_print_interval: print(f'πŸ”§ Check {check_count}: Exception getting node status - {e}') @@ -229,7 +236,9 @@ def __init__(self, proxy_manager: ProxyDistributedManager, config: dict[str, Any self.node_count = int(os.getenv('NODE_COUNT', '1')) self.proc_per_node = int(os.getenv('PROC_PER_NODE', '1')) - self.expected_instances = self.node_count * self.proc_per_node + _pc = run_config.get('parallel_config') or {} + _dp = int(_pc.get('dp', 0) or 0) + self.expected_instances = _dp if _dp > 1 else 1 self.is_master = (self.node_rank == 0) self.api_process = None