diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index 1cd04fec..7bdb9b64 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -43,7 +43,7 @@ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None): output_objects = [object] group_size = 1 if dist.is_available() and dist.is_initialized(): - if Platform.device_prefix() == 'npu': + if Platform.device_prefix() == 'npu' and not device_mesh.has_dim('fsdp'): # On NPU, letting Python object collectives use the default HCCL # group previously hung in 8-card metric collection at # ``dist.all_gather_object(...)``. Reuse Megatron's dedicated Gloo @@ -51,8 +51,9 @@ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None): # variant, otherwise the rank span for metric aggregation is wrong. if importlib.util.find_spec('megatron.core') is not None: from megatron.core import parallel_state as mpu - process_group = mpu.get_data_parallel_group_gloo( - with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1) + if mpu.model_parallel_is_initialized(): + process_group = mpu.get_data_parallel_group_gloo( + with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1) group_size = dist.get_world_size(group=process_group) if group_size > 1: output_objects = [None for _ in range(group_size)]