diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a165be78f..d2a27fda6 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -9,6 +9,7 @@ import setproctitle import threading import collections +import multiprocessing from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -62,12 +63,14 @@ def __init__( async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.model_procs: List[multiprocessing.Process] = [] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): for tp_rank_id in range(self.vit_tp): - rpc_model = await start_model_process() + rpc_model, proc = await start_model_process() self.model_rpcs[dp_rank_id].append(rpc_model) + self.model_procs.append(proc) init_model_ret = [] for dp_rank_id in range(self.vit_dp): # async init model process @@ -187,7 +190,14 @@ async def loop_for_netio_req(self): logger.exception(str(e)) def clean_up(self): - return + for proc in getattr(self, "model_procs", []): + try: + if proc.is_alive(): + logger.info(f"Killing VIT model process {proc.pid}") + proc.kill() + proc.join(timeout=5) + except (ProcessLookupError, OSError): + pass def start_visual_process(args, pipe_writer): diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index ae3c4204d..db61407fd 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -9,6 +9,7 @@ from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_env_start_args from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer @@ -18,6 +19,7 @@ def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ @@ -52,7 +54,7 @@ async def start_model_process(): # 服务端需要调用客户端传入的event所以,客户端需要一个后台线程进行相关的处理。 conn._bg_thread = rpyc.BgServingThread(conn, sleep_interval=0.001) - return VisualModelRpcClient(conn) + return VisualModelRpcClient(conn), proc def _generate_unix_socket_path() -> str: diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 27275c1e8..5a8a9981a 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -12,6 +12,7 @@ import os import signal import time +import multiprocessing from lightllm.utils.net_utils import get_hostname_ip from .objs import VIT_Obj from typing import List @@ -94,12 +95,14 @@ async def register_to_config_server_loop(self, args: StartArgs): async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.model_procs: List[multiprocessing.Process] = [] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): for tp_rank_id in range(self.vit_tp): - rpc_model = await start_model_process() + rpc_model, proc = await start_model_process() self.model_rpcs[dp_rank_id].append(rpc_model) + self.model_procs.append(proc) init_model_ret = [] for dp_rank_id in range(self.vit_dp): # async init model process @@ -130,7 +133,14 @@ async def infer_images(self, dp_index: int, images, events): await VisualManager.infer_images(self, dp_index=dp_index, images=images, events=events) def clean_up(self): - return + for proc in getattr(self, "model_procs", []): + try: + if proc.is_alive(): + logger.info(f"Killing VIT model process {proc.pid}") + proc.kill() + proc.join(timeout=5) + except (ProcessLookupError, OSError): + pass def exposed_remote_infer_images(self, images: List[ImageItem], ref_event: threading.Event): try: