diff --git a/cookbook/transformers/fsdp2_moe.py b/cookbook/transformers/fsdp2_moe.py
index 3ea649d3..cdc24ad6 100644
--- a/cookbook/transformers/fsdp2_moe.py
+++ b/cookbook/transformers/fsdp2_moe.py
@@ -8,6 +8,8 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.model import TransformersModel
 from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.utils import is_torch_npu_available
+from twinkle.kernel import apply_npu_patch
 
 # Construct a device_mesh, fsdp=4, dp=2
 device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
@@ -17,6 +19,11 @@
 logger = get_logger()
 
 
+# Apply Ascend NPU patches (device shims + MoE grouped matmul) before any model is built.
+if is_torch_npu_available():
+    apply_npu_patch()
+
+
 def eval(model):
     # 100 Samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
diff --git a/cookbook/transformers/fsdp2_moe_npu.sh b/cookbook/transformers/fsdp2_moe_npu.sh
new file mode 100644
index 00000000..349f9d0d
--- /dev/null
+++ b/cookbook/transformers/fsdp2_moe_npu.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+# CANN loading
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py
diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py
index fb07ba03..8de317bf 100644
--- a/src/twinkle/kernel/__init__.py
+++ b/src/twinkle/kernel/__init__.py
@@ -7,6 +7,7 @@
 from .function import apply_function_kernel, register_function_kernel
 from .layer import apply_layer_kernel, register_layer_batch, register_layer_kernel
 from .registry import register_external_layer as _register_external_layer
+from .monkey_patch_npu import apply_npu_patch
 
 logger = getLogger(__name__)
 
diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py
new file mode 100644
index 00000000..cff8ebc7
--- /dev/null
+++ b/src/twinkle/kernel/monkey_patch_npu.py
@@ -0,0 +1,110 @@
+# Ascend NPU monkey patches for twinkle.
+#
+# NOTE: `twinkle.kernel` imports this module unconditionally, so it must stay
+# importable on machines without torch_npu installed: `torch_npu` is imported
+# lazily inside the functions/methods that actually need it.
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class GmmFunction(torch.autograd.Function):
+    """Grouped matmul (MoE expert GEMM) with autograd support on Ascend NPU."""
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor):
+        """x: [M, K]; group_list: per-expert token counts [E]; weight_ekn: [E, K, N]."""
+        import torch_npu  # lazy: only present on NPU machines
+
+        assert x.dim() == 2, f"x must be [M, K], got {tuple(x.shape)}"
+        assert group_list.dim() == 1, f"group_list must be [E], got {tuple(group_list.shape)}"
+        assert weight_ekn.dim() == 3, f"weight_ekn must be [E, K, N], got {tuple(weight_ekn.shape)}"
+        assert group_list.numel() == weight_ekn.size(0), (
+            f"group_list len {group_list.numel()} != num_experts {weight_ekn.size(0)}"
+        )
+        assert x.size(1) == weight_ekn.size(1), (
+            f"input dim mismatch: x.shape={tuple(x.shape)}, weight_ekn.shape={tuple(weight_ekn.shape)}"
+        )
+
+        # npu_grouped_matmul requires an int64 group list.
+        group_list = group_list.to(torch.int64)
+
+        ctx.save_for_backward(x, group_list, weight_ekn)
+
+        outputs = torch_npu.npu_grouped_matmul(
+            [x],
+            [weight_ekn],
+            group_list=group_list,
+            group_type=0,
+            split_item=2,
+            group_list_type=1,
+        )
+        return outputs[0]
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        import torch_npu  # lazy: only present on NPU machines
+
+        x, group_list, weight_ekn = ctx.saved_tensors
+
+        # dX = dY @ W^T, grouped per expert.
+        grad_input = torch_npu.npu_grouped_matmul(
+            [grad_output],
+            [weight_ekn.transpose(-2, -1).contiguous()],
+            bias=None,
+            group_list=group_list,
+            group_type=0,
+            split_item=2,
+            group_list_type=1,
+        )[0]
+
+        # dW = X^T @ dY, grouped per expert.
+        grad_weight = torch_npu.npu_grouped_matmul(
+            [x.transpose(0, 1)],
+            [grad_output],
+            bias=None,
+            group_list=group_list,
+            group_type=2,
+            split_item=3,
+            group_list_type=1,
+        )[0]
+
+        # group_list is integer bookkeeping -> no gradient.
+        return grad_input, None, grad_weight.contiguous()
+
+
+def _grouped_mm_npu(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
+    """Drop-in replacement for transformers' `_grouped_mm` on Ascend NPU.
+
+    `offs` holds cumulative per-expert offsets; they are converted to
+    per-expert counts before dispatching to the grouped matmul kernel.
+    """
+    assert input.dim() == 2, f"input must be [M, K], got {tuple(input.shape)}"
+    assert weight_ekn.dim() == 3, f"weight_ekn must be [E, K, N], got {tuple(weight_ekn.shape)}"
+    assert offs.dim() == 1, f"offs must be [E], got {tuple(offs.shape)}"
+    assert weight_ekn.size(0) == offs.numel(), (
+        f"weight_ekn.size(0)={weight_ekn.size(0)} != offs.numel()={offs.numel()}"
+    )
+
+    # cumulative offsets -> per-group counts; safe for E == 0 or E == 1.
+    counts = torch.diff(offs, prepend=offs.new_zeros(1)).to(torch.int64)
+
+    return GmmFunction.apply(input, counts, weight_ekn)
+
+
+def apply_hf_moe_grouped_mm_patch():
+    """Route transformers' MoE grouped matmul through the NPU kernel."""
+    import transformers.integrations.moe as hf_moe
+
+    hf_moe._grouped_mm = _grouped_mm_npu
+    logger.info("[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu")
+
+
+def apply_npu_patch():
+    """Enable Ascend NPU support: device-transfer shims + MoE grouped matmul patch."""
+    import torch_npu  # noqa: F401
+    from torch_npu.contrib import transfer_to_npu  # noqa: F401  (import has side effects)
+
+    apply_hf_moe_grouped_mm_patch()
diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py
index cca7e63b..b852fa99 100644
--- a/src/twinkle/utils/__init__.py
+++ b/src/twinkle/utils/__init__.py
@@ -14,5 +14,5 @@
                               stateless_init_process_group, to_device)
 from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert
 from .unsafe import check_unsafe, trust_remote_code
-from .utils import copy_files_by_pattern, deep_getattr
+from .utils import copy_files_by_pattern, deep_getattr, is_torch_npu_available
 from .vision_tools import load_image, load_mm_file
diff --git a/src/twinkle/utils/utils.py b/src/twinkle/utils/utils.py
index 0b0ae4d0..610489eb 100644
--- a/src/twinkle/utils/utils.py
+++ b/src/twinkle/utils/utils.py
@@ -77,3 +77,13 @@
         destination = os.path.join(dest_dir, file_name)
         if not os.path.exists(destination):
             shutil.copy2(file_path, destination)
+
+
+def is_torch_npu_available():
+    """Return True iff torch_npu is importable and an Ascend NPU device is usable."""
+    try:
+        import torch
+        import torch_npu  # noqa: F401
+        return hasattr(torch, "npu") and torch.npu.is_available()
+    except Exception:
+        return False