Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cookbook/transformers/fsdp2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor
from twinkle.utils import is_torch_npu_available
from twinkle.kernel import apply_npu_patch

# Construct a device_mesh, fsdp=4, dp=2
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
Expand All @@ -17,6 +19,11 @@
logger = get_logger()


# Apply the Ascend NPU monkey patches (MoE grouped matmul override and
# CUDA->NPU call transfer) before any model construction or tensor work.
# No-op on non-NPU hosts: is_torch_npu_available() returns False there.
if is_torch_npu_available():
    apply_npu_patch()


def eval(model):
# 100 Samples
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
Expand Down
6 changes: 6 additions & 0 deletions cookbook/transformers/fsdp2_moe_npu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

# CANN loading: source the Ascend toolkit environment (sets library and tool
# paths required by torch_npu).
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Launch the FSDP2 MoE cookbook script on all 8 NPUs of this node
# (matches the fsdp_size=4 x dp_size=2 device mesh in fsdp2_moe.py).
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py
1 change: 1 addition & 0 deletions src/twinkle/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .function import apply_function_kernel, register_function_kernel
from .layer import apply_layer_kernel, register_layer_batch, register_layer_kernel
from .registry import register_external_layer as _register_external_layer
from .monkey_patch_npu import apply_npu_patch

logger = getLogger(__name__)

Expand Down
86 changes: 86 additions & 0 deletions src/twinkle/kernel/monkey_patch_npu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import functools
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refer to VERL, aggregate all patches in this file, and perform full replacement externally only after is_npu_available() at twinkle initialization.

import torch
import torch_npu

class GmmFunction(torch.autograd.Function):
    """Autograd-aware grouped matmul backed by ``torch_npu.npu_grouped_matmul``.

    Computes, per expert ``e``, ``out[rows_e] = x[rows_e] @ weight_ekn[e]``,
    where ``group_list`` gives the number of rows routed to each expert
    (``group_list_type=1`` means per-group counts, not cumulative offsets).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor):
        """Forward grouped matmul.

        Args:
            x: ``[M, K]`` input rows, already sorted/grouped by expert.
            group_list: ``[E]`` per-expert row counts.
            weight_ekn: ``[E, K, N]`` expert weight stack.

        Returns:
            ``[M, N]`` output tensor.
        """
        assert x.dim() == 2, f"x must be [M, K], got {tuple(x.shape)}"
        assert group_list.dim() == 1, f"group_list must be [E], got {tuple(group_list.shape)}"
        assert weight_ekn.dim() == 3, f"weight_ekn must be [E, K, N], got {tuple(weight_ekn.shape)}"
        assert group_list.numel() == weight_ekn.size(0), (
            f"group_list len {group_list.numel()} != num_experts {weight_ekn.size(0)}"
        )
        assert x.size(1) == weight_ekn.size(1), (
            f"input dim mismatch: x.shape={tuple(x.shape)}, weight_ekn.shape={tuple(weight_ekn.shape)}"
        )

        # npu_grouped_matmul requires an int64 group list. Redundant when called
        # via _grouped_mm_npu (which already converts), but kept so the Function
        # is safe to call directly with int32 counts; no-op if already int64.
        group_list = group_list.to(torch.int64)
        ctx.save_for_backward(x, group_list, weight_ekn)

        outputs = torch_npu.npu_grouped_matmul(
            [x],
            [weight_ekn],
            group_list=group_list,
            group_type=0,
            split_item=2,
            group_list_type=1,
        )
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Backward pass: grad wrt ``x`` and ``weight_ekn`` (none for group_list)."""
        x, group_list, weight_ekn = ctx.saved_tensors

        # dL/dx[rows_e] = grad_output[rows_e] @ weight_ekn[e].T
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output],
            [weight_ekn.transpose(-2, -1).contiguous()],
            bias=None,
            group_list=group_list,
            group_type=0,
            split_item=2,
            group_list_type=1,
        )[0]

        # dL/dW[e] = x[rows_e].T @ grad_output[rows_e]; group_type=2 groups
        # along the reduction (M) axis to produce the [E, K, N] stack.
        grad_weight = torch_npu.npu_grouped_matmul(
            [x.transpose(0, 1)],
            [grad_output],
            bias=None,
            group_list=group_list,
            group_type=2,
            split_item=3,
            group_list_type=1,
        )[0]

        # Gradients align with forward's (x, group_list, weight_ekn) signature.
        return grad_input, None, grad_weight.contiguous()


def _grouped_mm_npu(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """NPU drop-in for ``transformers.integrations.moe._grouped_mm``.

    The HF integration passes *cumulative* end offsets per expert (``offs``);
    ``npu_grouped_matmul`` with ``group_list_type=1`` wants per-expert *counts*,
    so the offsets are differenced before dispatching to :class:`GmmFunction`.

    Args:
        input: ``[M, K]`` rows already sorted by expert.
        weight_ekn: ``[E, K, N]`` expert weight stack.
        offs: ``[E]`` cumulative row end-offsets (``offs[-1] == M``).

    Returns:
        ``[M, N]`` output tensor with autograd support.
    """
    assert input.dim() == 2, f"input must be [M, K], got {tuple(input.shape)}"
    assert weight_ekn.dim() == 3, f"weight_ekn must be [E, K, N], got {tuple(weight_ekn.shape)}"
    assert offs.dim() == 1, f"offs must be [E], got {tuple(offs.shape)}"
    assert weight_ekn.size(0) == offs.numel(), (
        f"weight_ekn.size(0)={weight_ekn.size(0)} != offs.numel()={offs.numel()}"
    )

    # Cumulative offsets -> per-group counts: counts[0] = offs[0],
    # counts[i] = offs[i] - offs[i-1]. The empty-slice in-place subtract is a
    # no-op when offs has a single element.
    counts = offs.clone()
    counts[1:] -= offs[:-1]
    counts = counts.to(torch.int64)

    return GmmFunction.apply(input, counts, weight_ekn)


def apply_hf_moe_grouped_mm_patch():
    """Route HF transformers MoE grouped matmul through the NPU kernel.

    Replaces ``transformers.integrations.moe._grouped_mm`` with
    :func:`_grouped_mm_npu` so MoE expert dispatch runs on
    ``npu_grouped_matmul``. Process-wide and irreversible for this run.
    """
    import logging

    # Imported lazily so this module stays importable without transformers.
    import transformers.integrations.moe as hf_moe

    hf_moe._grouped_mm = _grouped_mm_npu
    # Log instead of print(), consistent with the package's logging setup.
    logging.getLogger(__name__).info(
        "[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu")

def apply_npu_patch():
    """Apply all Ascend NPU monkey patches in one place.

    Intended to be called once at startup, and only after
    ``is_torch_npu_available()`` has returned True (importing torch_npu on a
    non-NPU host would fail).
    """
    # Side-effect import: transfer_to_npu transparently redirects
    # torch.cuda.* calls to the NPU backend. The redundant local
    # `import torch` / `import torch_npu` were dropped — both are already
    # module-level imports.
    from torch_npu.contrib import transfer_to_npu  # noqa: F401

    apply_hf_moe_grouped_mm_patch()
2 changes: 1 addition & 1 deletion src/twinkle/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
stateless_init_process_group, to_device)
from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert
from .unsafe import check_unsafe, trust_remote_code
from .utils import copy_files_by_pattern, deep_getattr
from .utils import copy_files_by_pattern, deep_getattr, is_torch_npu_available
from .vision_tools import load_image, load_mm_file
8 changes: 8 additions & 0 deletions src/twinkle/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,11 @@ def should_exclude_file(file_path, file_name):
destination = os.path.join(dest_dir, file_name)
if not os.path.exists(destination):
shutil.copy2(file_path, destination)

def is_torch_npu_available():
    """Return True iff torch_npu is importable and an Ascend NPU device is usable.

    Best-effort probe: any failure (missing package, broken driver, runtime
    error from the backend) is reported as unavailable rather than raised.
    """
    try:
        import torch
        import torch_npu  # noqa: F401
    except Exception:
        # torch or torch_npu not installed / not importable on this host.
        return False
    try:
        return hasattr(torch, "npu") and torch.npu.is_available()
    except Exception:
        # Backend present but device query failed — treat as unavailable.
        return False