From ddd8f40ff5304ef934c465f7cf58f5c39ba7f6a9 Mon Sep 17 00:00:00 2001 From: zhaoshijie Date: Thu, 29 Jan 2026 21:04:44 +0800 Subject: [PATCH 1/2] complete embedding/hardsigmoid/hardshrink/argmin/cosine_embedding_loss kernel and test --- src/ntops/kernels/__init__.py | 10 + src/ntops/kernels/argmin.py | 273 +++++++++++++++++++++ src/ntops/kernels/cosine_embedding_loss.py | 133 ++++++++++ src/ntops/kernels/embedding.py | 138 +++++++++++ src/ntops/kernels/hardshrink.py | 30 +++ src/ntops/kernels/hardsigmoid.py | 43 ++++ src/ntops/torch/__init__.py | 10 + src/ntops/torch/argmin.py | 53 ++++ src/ntops/torch/cosine_embedding_loss.py | 69 ++++++ src/ntops/torch/embedding.py | 228 +++++++++++++++++ src/ntops/torch/hardshrink.py | 21 ++ src/ntops/torch/hardsigmoid.py | 21 ++ tests/test_argmin.py | 40 +++ tests/test_cosine_embedding_loss.py | 67 +++++ tests/test_embedding.py | 32 +++ tests/test_hardshrink.py | 24 ++ tests/test_hardsigmoid.py | 31 +++ 17 files changed, 1223 insertions(+) create mode 100644 src/ntops/kernels/argmin.py create mode 100644 src/ntops/kernels/cosine_embedding_loss.py create mode 100644 src/ntops/kernels/embedding.py create mode 100644 src/ntops/kernels/hardshrink.py create mode 100644 src/ntops/kernels/hardsigmoid.py create mode 100644 src/ntops/torch/argmin.py create mode 100644 src/ntops/torch/cosine_embedding_loss.py create mode 100644 src/ntops/torch/embedding.py create mode 100644 src/ntops/torch/hardshrink.py create mode 100644 src/ntops/torch/hardsigmoid.py create mode 100644 tests/test_argmin.py create mode 100644 tests/test_cosine_embedding_loss.py create mode 100644 tests/test_embedding.py create mode 100644 tests/test_hardshrink.py create mode 100644 tests/test_hardsigmoid.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..e4b8cb5 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -39,6 +39,11 @@ softmax, sub, tanh, + embedding, + cosine_embedding_loss, + 
hardshrink, + argmin, + hardsigmoid, ) __all__ = [ @@ -82,4 +87,9 @@ "softmax", "sub", "tanh", + "embedding", + "cosine_embedding_loss", + "hardshrink", + "argmin", + "hardsigmoid", ] diff --git a/src/ntops/kernels/argmin.py b/src/ntops/kernels/argmin.py new file mode 100644 index 0000000..1d25c67 --- /dev/null +++ b/src/ntops/kernels/argmin.py @@ -0,0 +1,273 @@ +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement( + input, + output, + target_dim_size, + in_dims=None, + out_dims=None, + axis=None, + axis_is_none=False, + target_dim_size_power2=None, +): + if axis_is_none: + input = input.flatten() + input_arranged = input.tile((target_dim_size_power2,)) + if out_dims > 0: + output = output.flatten() + output_arranged = output.tile((-1,)) + else: + output_arranged = output.unsqueeze(dim=0) + output_arranged = output_arranged.tile((1,)) + else: + if in_dims == 4: + if axis == 0: + input_arranged = input.tile((target_dim_size_power2, -1, -1, -1)) + elif axis == 1: + input_arranged = input.tile((-1, target_dim_size_power2, -1, -1)) + elif axis == 2: + input_arranged = input.tile((-1, -1, target_dim_size_power2, -1)) + elif axis == 3: + input_arranged = input.tile((-1, -1, -1, target_dim_size_power2)) + input_arranged = input_arranged.squeeze(0) + input_arranged = input_arranged.squeeze(0) + input_arranged = input_arranged.squeeze(0) + elif in_dims == 3: + if axis == 0: + input_arranged = input.tile((target_dim_size_power2, -1, -1)) + elif axis == 1: + input_arranged = input.tile((-1, target_dim_size_power2, -1)) + elif axis == 2: + input_arranged = input.tile((-1, -1, target_dim_size_power2)) + input_arranged = input_arranged.squeeze(0) + input_arranged = input_arranged.squeeze(0) + elif in_dims == 2: + if axis == 0: + input_arranged = input.tile((target_dim_size_power2, -1)) + elif axis == 1: + input_arranged = input.tile((-1, target_dim_size_power2)) + input_arranged = input_arranged.squeeze(0) 
+ elif in_dims == 1: + input_arranged = input.tile((target_dim_size_power2,)) + + if out_dims == 4: + output_arranged = output.tile((-1, -1, -1, -1)) + output_arranged = output_arranged.squeeze(0) + output_arranged = output_arranged.squeeze(0) + output_arranged = output_arranged.squeeze(0) + elif out_dims == 3: + output_arranged = output.tile((-1, -1, -1)) + output_arranged = output_arranged.squeeze(0) + output_arranged = output_arranged.squeeze(0) + elif out_dims == 2: + output_arranged = output.tile((-1, -1)) + output_arranged = output_arranged.squeeze(0) + elif out_dims == 1: + output_arranged = output.tile((-1,)) + elif out_dims == 0: + output_arranged = output.unsqueeze(dim=0) + output_arranged = output_arranged.tile((1,)) + + return input_arranged, output_arranged, target_dim_size + + +def application_0_1(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[0]) < target_dim_size + if input.ndim == 1: + mask = valid_mask + elif input.ndim == 2: + mask = valid_mask[:, None] + elif input.ndim == 3: + mask = valid_mask[:, None, None] + elif input.ndim == 4: + mask = valid_mask[:, None, None, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=0, keep_dims=1) + + +def application_1_1(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[1]) < target_dim_size + if input.ndim == 2: + mask = valid_mask[None, :] + elif input.ndim == 3: + mask = valid_mask[None, :, None] + elif input.ndim == 4: + mask = valid_mask[None, :, None, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=1, keep_dims=1) + + +def application_2_1(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[2]) < target_dim_size + if input.ndim == 3: + mask = valid_mask[None, None, :] + elif input.ndim == 4: + 
mask = valid_mask[None, None, :, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=2, keep_dims=1) + + +def application_3_1(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[3]) < target_dim_size + if input.ndim == 4: + mask = valid_mask[None, None, None, :] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=3, keep_dims=1) + + +def application_0_0(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[0]) < target_dim_size + if input.ndim == 1: + mask = valid_mask + elif input.ndim == 2: + mask = valid_mask[:, None] + elif input.ndim == 3: + mask = valid_mask[:, None, None] + elif input.ndim == 4: + mask = valid_mask[:, None, None, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=0, keep_dims=0) + + +def application_1_0(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[1]) < target_dim_size + if input.ndim == 2: + mask = valid_mask[None, :] + elif input.ndim == 3: + mask = valid_mask[None, :, None] + elif input.ndim == 4: + mask = valid_mask[None, :, None, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=1, keep_dims=0) + + +def application_2_0(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[2]) < target_dim_size + if input.ndim == 3: + mask = valid_mask[None, None, :] + elif input.ndim == 4: + mask = valid_mask[None, None, :, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=2, keep_dims=0) + 
+ +def application_3_0(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[3]) < target_dim_size + if input.ndim == 4: + mask = valid_mask[None, None, None, :] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + output = ntl.argmin(masked_input, axis=3, keep_dims=0) + + +def application_0_0_scalar(input, output, target_dim_size): + valid_mask = ntl.arange(0, input.shape[0]) < target_dim_size + if input.ndim == 1: + mask = valid_mask + elif input.ndim == 2: + mask = valid_mask[:, None] + elif input.ndim == 3: + mask = valid_mask[:, None, None] + elif input.ndim == 4: + mask = valid_mask[:, None, None, None] + else: + raise ValueError(f"Unsupported input ndim: {input.ndim}") + masked_input = ntl.where(mask, input, float("inf")) + result = ntl.argmin(masked_input, axis=0, keep_dims=0) + output[0] = result + + +def premake( + target_dim_size, + dtype=None, + in_dims=None, + out_dims=None, + axis=None, + axis_is_none=False, + keep_dims=None, +): + import math + + target_dim_size_power2 = ( + 2 ** math.ceil(math.log2(int(target_dim_size))) + if int(target_dim_size) > 0 + else 1 + ) + arrangement_ = functools.partial( + arrangement, + in_dims=in_dims, + out_dims=out_dims, + axis=axis, + axis_is_none=axis_is_none, + target_dim_size_power2=target_dim_size_power2, + ) + + tensors = ( + Tensor(in_dims, dtype=dtype, shape_options={"constexpr": True}), + Tensor(out_dims, dtype=ninetoothed.int64, shape_options={"constexpr": True}), + Tensor(0, constexpr=True, value=target_dim_size), + ) + # if out_dims > 0: + # tensors = ( + # Tensor(in_dims, dtype=dtype, shape_options={"constexpr": True}), + # Tensor(out_dims, dtype=ninetoothed.int64, shape_options={"constexpr": True}), + # ) + # else: + # # 标量输出:不使用 constexpr + # tensors = ( + # Tensor(in_dims, dtype=dtype, shape_options={"constexpr": True}), + # Tensor(0, dtype=ninetoothed.int64), # ← 删除 shape_options + # ) + + if axis_is_none: + if 
keep_dims: + application = application_0_1 + else: + application = application_0_0 + else: + if axis == 0: + if keep_dims: + application = application_0_1 + else: + application = application_0_0 + elif axis == 1: + if keep_dims: + application = application_1_1 + else: + application = application_1_0 + elif axis == 2: + if keep_dims: + application = application_2_1 + else: + application = application_2_0 + elif axis == 3: + if keep_dims: + application = application_3_1 + else: + application = application_3_0 + else: + raise ValueError(f"Unsupported axis: {axis}") + + if out_dims == 0: + application = application_0_0_scalar + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/cosine_embedding_loss.py b/src/ntops/kernels/cosine_embedding_loss.py new file mode 100644 index 0000000..8b461db --- /dev/null +++ b/src/ntops/kernels/cosine_embedding_loss.py @@ -0,0 +1,133 @@ +import enum +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +BLOCK_SIZE = ninetoothed.block_size() + + +def cosine_embedding_loss_arrangement( + x1, x2, y, margin, output, block_size=None, dims=None, embedding_dim=None +): + """用于计算余弦相似度的arrangement""" + if block_size is None: + block_size = BLOCK_SIZE + + if dims == 1: + x1 = x1.unsqueeze(0) + x2 = x2.unsqueeze(0) + y = y.unsqueeze(0) + output = output.unsqueeze(0) + + x1_arranged = x1.flatten(start_dim=0, end_dim=-1) + x1_arranged = x1_arranged.tile((1, embedding_dim)) + x1_arranged = x1_arranged.squeeze(1) + x1_arranged.dtype = x1_arranged.dtype.squeeze(0) + x1_arranged = x1_arranged.tile((1,)) + + x2_arranged = x2.flatten(start_dim=0, end_dim=-1) + x2_arranged = x2_arranged.tile((1, embedding_dim)) + x2_arranged = x2_arranged.squeeze(1) + x2_arranged.dtype = x2_arranged.dtype.squeeze(0) + x2_arranged = x2_arranged.tile((1,)) + + y_arranged = y.flatten() + y_arranged = y_arranged.tile((1,)) + + # margin_arranged = margin + + output_arranged = output.flatten() + 
output_arranged = output_arranged.tile((1,)) + + return x1_arranged, x2_arranged, y_arranged, margin, output_arranged + + +def cosine_embedding_loss_application(x1, x2, y, margin, output): + """计算余弦相似度""" + dot_product = 0.0 + norm1_sq = 0.0 + norm2_sq = 0.0 + dot_product += ntl.sum(x1[0] * x2[0]) + norm1_sq += ntl.sum(x1[0] * x1[0]) + norm2_sq += ntl.sum(x2[0] * x2[0]) + norm1 = ntl.sqrt(norm1_sq) + norm2 = ntl.sqrt(norm2_sq) + + cosine = dot_product / (norm1 * norm2) + + if y[0] > 0: + output[0] = 1.0 - cosine + else: + diff = cosine - margin + output[0] = ntl.maximum(0.0, diff) + + +def cosine_embedding_loss_premake( + dtype=None, + block_size=None, + dims=None, + embedding_dim=None, +): + import math + + embedding_dim_power_of_2 = ( + 2 ** math.ceil(math.log2(embedding_dim)) if embedding_dim > 0 else 1 + ) + arrangement_ = functools.partial( + cosine_embedding_loss_arrangement, + block_size=block_size, + dims=dims, + embedding_dim=embedding_dim_power_of_2, + ) + + tensors = ( + Tensor(dims, dtype=dtype), + Tensor(dims, dtype=dtype), + Tensor(dims - 1, dtype=ninetoothed.int32), + Tensor(0, dtype=ninetoothed.float32), + Tensor(dims - 1, dtype=dtype), + ) + + return arrangement_, cosine_embedding_loss_application, tensors + + +def arrangement_all_elements(input, output, block_size=None): + input = input.flatten().tile((block_size,)) + output = output.tile((1,)) + return input, output + + +def application_all_elements(input, output): + output[0] = ntl.sum(input, 0) + + +def reduce_sum_premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(1, dtype=dtype), + ) + + return arrangement_, application_all_elements, tensors + + +from ntops.kernels.element_wise import arrangement + + +def div_application(input, other, output): + output = input / other # noqa: F841 + + +def div_premake(ndim, dtype=None, block_size=None): + arrangement_ = 
functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.int32), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, div_application, tensors diff --git a/src/ntops/kernels/embedding.py b/src/ntops/kernels/embedding.py new file mode 100644 index 0000000..1c15488 --- /dev/null +++ b/src/ntops/kernels/embedding.py @@ -0,0 +1,138 @@ +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ninetoothed.language import libdevice + +BLOCK_SIZE = ninetoothed.block_size() + + +def arrangement( + input, + weight, + output, + max_norm, + norm_type, + block_size_m=None, + block_size_n=None, + embedding_dim=None, +): + if block_size_m is None: + block_size_m = BLOCK_SIZE + if block_size_n is None: + block_size_n = BLOCK_SIZE + + output = output.flatten(start_dim=0, end_dim=-1) + input = input.flatten() + + output_arranged = output.tile((1, embedding_dim)) + output_arranged.dtype = output_arranged.dtype.squeeze(0) + output_arranged = output_arranged.tile((1, 1)) + output_arranged.dtype = output_arranged.dtype.squeeze(0) + output_arranged = output_arranged.squeeze(1) + + input_arranged = input.tile((1,)) + + weight_arranged = weight.tile((1, embedding_dim)) + weight_arranged.dtype = weight_arranged.dtype.squeeze(0) + weight_arranged = weight_arranged.squeeze(1) + weight_arranged = weight_arranged.tile((-1,)) + weight_arranged = weight_arranged.expand((output_arranged.shape[0],)) + + return input_arranged, weight_arranged, output_arranged, max_norm, norm_type + + +def application(input, weight, output, max_norm, norm_type): + idx = input[0] + tmp = ntl.zeros(weight[0].shape, dtype=ntl.float32) + tmp = ntl.abs(weight[idx]) + tmp = libdevice.pow(tmp, norm_type) + sum = ntl.sum(tmp) + norm = libdevice.pow(sum, 1.0 / norm_type) + + if norm > max_norm: + scale = max_norm / norm + weight[idx] = weight[idx] * scale + output[0] = weight[idx] + + +def 
premake(ndim, embedding_dim=None, dtype=None, block_size_m=None, block_size_n=None): + import math + + embedding_dim_power_of_2 = ( + 2 ** math.ceil(math.log2(embedding_dim)) if embedding_dim > 0 else 1 + ) + arrangement_ = functools.partial( + arrangement, + block_size_m=block_size_m, + block_size_n=block_size_n, + embedding_dim=embedding_dim_power_of_2, + ) + + tensors = ( + Tensor(ndim, dtype=ninetoothed.int64), + Tensor(2, dtype=dtype), + Tensor(ndim + 1, dtype=dtype), + Tensor(0, dtype=ninetoothed.float32), + Tensor(0, dtype=ninetoothed.float32), + ) + + return arrangement_, application, tensors + + +def arrangement_without_norm( + input, + weight, + output, + block_size_m=None, + block_size_n=None, +): + if block_size_m is None: + block_size_m = BLOCK_SIZE + if block_size_n is None: + block_size_n = BLOCK_SIZE + + output = output.flatten(start_dim=0, end_dim=-1) + input = input.flatten() + # below commetned out code can be ussed for optional params == None + # output: [N, embedding_dim] -> tile to [block_size_m, block_size_n] + output_arranged = output.tile((1, block_size_n)) + output_arranged = output_arranged.tile((block_size_m, 1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + # input: [N] -> tile to [block_size_m], 这是索引 + input = input.unsqueeze(1) # 增加一个维度以便进行expand + input_arranged = input.tile((block_size_m, 1)) + input_arranged = input_arranged.expand((-1, output_arranged.shape[1])) + input_arranged.dtype = input_arranged.dtype.squeeze(1) + + # weight: [vocab_size, embedding_dim] -> tile embedding_dim dimension + # 然后 expand 到匹配 output 的 batch 维度 + weight_arranged = weight.tile((1, block_size_n)) + weight_arranged = weight_arranged.tile((-1, 1)) # 保持所有 vocab 行 + weight_arranged = weight_arranged.expand((output_arranged.shape[0], -1)) + weight_arranged.dtype = weight_arranged.dtype.squeeze(1) + + return input_arranged, weight_arranged, output_arranged + + +def application_without_norm(input, weight, output): + for i in 
range(output.shape[0]): + idx = input[i] + output[i] = weight[idx] + + +def premake_without_norm(ndim, dtype=None, block_size_m=None, block_size_n=None): + arrangement_ = functools.partial( + arrangement_without_norm, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + tensors = ( + Tensor(ndim, dtype=ninetoothed.int64), + Tensor(2, dtype=dtype), + Tensor(ndim + 1, dtype=dtype), + ) + + return arrangement_, application_without_norm, tensors diff --git a/src/ntops/kernels/hardshrink.py b/src/ntops/kernels/hardshrink.py new file mode 100644 index 0000000..c4cd647 --- /dev/null +++ b/src/ntops/kernels/hardshrink.py @@ -0,0 +1,30 @@ +import enum +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, lambd, output): + output = ntl.where(ntl.abs(input) > lambd, input, 0.0) + + +def hardshrink_premake( + ndim, + dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.float32), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/hardsigmoid.py b/src/ntops/kernels/hardsigmoid.py new file mode 100644 index 0000000..ed63ff7 --- /dev/null +++ b/src/ntops/kernels/hardsigmoid.py @@ -0,0 +1,43 @@ +import enum +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application_default(input, output): + output = ntl.where(input > -3 and input < 3, input / 6 + 0.5, input) + output = ntl.where(output <= -3, 0, output) + output = ntl.where(output >= 3, 1, output) + + +def application_inplace(input): + input = ntl.where(input > -3 and input < 3, input / 6 + 0.5, input) + input = ntl.where(input <= -3, 0, input) + input = ntl.where(input >= 3, 
1, input) + + +def premake( + ndim, + inplace=False, + dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + if inplace: + tensors = (Tensor(ndim, dtype=dtype),) + else: + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + application = application_inplace if inplace else application_default + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..eab7b72 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -39,6 +39,11 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.embedding import embedding +from ntops.torch.cosine_embedding_loss import cosine_embedding_loss +from ntops.torch.hardshrink import hardshrink +from ntops.torch.argmin import argmin +from ntops.torch.hardsigmoid import hardsigmoid __all__ = [ "abs", @@ -82,4 +87,9 @@ "softmax", "sub", "tanh", + "embedding", + "cosine_embedding_loss", + "hardshrink", + "argmin", + "hardsigmoid", ] diff --git a/src/ntops/torch/argmin.py b/src/ntops/torch/argmin.py new file mode 100644 index 0000000..e8b3504 --- /dev/null +++ b/src/ntops/torch/argmin.py @@ -0,0 +1,53 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + + +def argmin(input, axis=None, keepdim=False): + if axis is None: + target_dim = 0 + axis_is_none = True + else: + target_dim = axis + axis_is_none = False + + if axis is None: + axis = tuple(range(len(input.shape))) + elif isinstance(axis, int): + axis = (axis,) + else: + axis = tuple(axis) + + output_shape = list(input.shape) + for ax in axis: + output_shape[ax] = 1 if keepdim else 0 + output_shape = tuple([s for s in output_shape if s > 0]) + + output = torch.empty(output_shape, dtype=torch.int64, device=input.device) + + # print(f"input shape: {input.shape}, output shape: {output.shape}, axis: {axis}, keepdim: 
{keepdim}") + # print(f"input dtype: {input.dtype}, output dtype: {output.dtype}, output dim: {output.dim()}") + # print(f"target_dim: {target_dim}, axis_is_none: {axis_is_none}") + # 计算input含有的元素数量,通过乘积计算得到 + num_elements = 1 + for s in input.shape: + num_elements *= s + + kernel = _cached_make( + ntops.kernels.argmin.premake, + input.shape[target_dim] if not axis_is_none else num_elements, + dtype=input.dtype, + in_dims=len(input.shape), + out_dims=len(output.shape), + axis=target_dim, + axis_is_none=axis_is_none, + keep_dims=keepdim, + ) + + kernel( + input, + output, + input.shape[target_dim] if not axis_is_none else num_elements, + ) + + return output diff --git a/src/ntops/torch/cosine_embedding_loss.py b/src/ntops/torch/cosine_embedding_loss.py new file mode 100644 index 0000000..3d21a93 --- /dev/null +++ b/src/ntops/torch/cosine_embedding_loss.py @@ -0,0 +1,69 @@ +import torch +import math +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + + if target_size > 1024: + target_size = 1024 + + if target_size < 32: + target_size = 32 + + return target_size + + +def cosine_embedding_loss(x1, x2, y, margin=0.0, reduction="mean"): + embedding_dim = x1.shape[-1] + dims = len(x1.shape) + xshape = list(x1.shape) + output = torch.empty(xshape[:-1], dtype=x1.dtype, device=x1.device) + kernel_loss = _cached_make( + ntops.kernels.cosine_embedding_loss.cosine_embedding_loss_premake, + dims=dims, + embedding_dim=embedding_dim, + block_size=16, + ) + kernel_loss(x1, x2, y, margin, output) + + if reduction == "none": + return output + else: + cur = output + block_size = get_optimal_block_size(cur.numel()) + while cur.numel() > 1: + cur_output_len = math.ceil(cur.numel() / block_size) + cur_output = torch.empty( + (cur_output_len,), dtype=cur.dtype, device=cur.device + ) + kernel_sum = _cached_make( 
+ ntops.kernels.cosine_embedding_loss.reduce_sum_premake, + cur.ndim, + cur.dtype, + block_size, + ) + kernel_sum(cur, cur_output) + cur = cur_output + res = cur.view(()) + res = cur.unsqueeze(0) + if reduction == "mean": + mean_out = torch.empty_like(res) + kernel_div = _cached_make( + ntops.kernels.cosine_embedding_loss.div_premake, res.ndim + ) + other = output.numel() + kernel_div(res, other, mean_out) + mean_out = mean_out.view(()) + return mean_out + else: + res = res.view(()) + return res diff --git a/src/ntops/torch/embedding.py b/src/ntops/torch/embedding.py new file mode 100644 index 0000000..251d68f --- /dev/null +++ b/src/ntops/torch/embedding.py @@ -0,0 +1,228 @@ +import torch +import numpy as np +import ctypes +import ntops +from ntops.torch.utils import _cached_make + + +def _get_dtype_info(dtype): + dtype_str = str(dtype) + for prefix in ("infinicore.", "torch."): + if dtype_str.startswith(prefix): + dtype_str = dtype_str[len(prefix) :] + break + info = _DTYPE_MAP.get(dtype_str) + if info is None: + raise RuntimeError(f"Unsupported dtype: {dtype}") + return info + + +def _is_cpu(tensor): + device_str = str(tensor.device).lower() + return "cpu" in device_str + + +def _get_cpu_device(): + return torch.device("cpu", 0) + + +_DTYPE_MAP = { + "float16": (np.float16, 2), + "float32": (np.float32, 4), + "float64": (np.float64, 8), + "bfloat16": (None, 2), + "int8": (np.int8, 1), + "int16": (np.int16, 2), + "int32": (np.int32, 4), + "int64": (np.int64, 8), + "uint8": (np.uint8, 1), +} + + +def _bf16_to_fp32(bf16_uint16_arr): + u32 = bf16_uint16_arr.astype(np.uint32) << 16 + return u32.view(np.float32) + + +def _fp32_to_bf16(fp32_arr): + u32 = fp32_arr.view(np.uint32) + u16 = (u32 >> 16).astype(np.uint16) + return u16 + + +def _tensor_to_numpy(tensor): + tensor = tensor.contiguous() + _, elem_bytes = _get_dtype_info(tensor.dtype) + numel = tensor.numel() + total_bytes = numel * elem_bytes + shape = list(tensor.shape) + + dtype_str = str(tensor.dtype) + for 
prefix in ("infinicore.", "torch."): + if dtype_str.startswith(prefix): + dtype_str = dtype_str[len(prefix) :] + break + is_bf16 = dtype_str == "bfloat16" + np_dtype = _DTYPE_MAP[dtype_str][0] + + if _is_cpu(tensor): + ptr = tensor.data_ptr() + buf = (ctypes.c_byte * total_bytes).from_address(ptr) + if is_bf16: + raw = np.frombuffer(buf, dtype=np.uint16).copy() + np_arr = _bf16_to_fp32(raw) + else: + np_arr = np.frombuffer(buf, dtype=np_dtype).copy() + return np_arr.reshape(shape) + else: + cpu_dev = _get_cpu_device() + cpu_tensor = torch.empty(shape, dtype=tensor.dtype, device=cpu_dev) + cpu_tensor.copy_(tensor) + ptr = cpu_tensor.data_ptr() + buf = (ctypes.c_byte * total_bytes).from_address(ptr) + if is_bf16: + raw = np.frombuffer(buf, dtype=np.uint16).copy() + np_arr = _bf16_to_fp32(raw) + else: + np_arr = np.frombuffer(buf, dtype=np_dtype).copy() + return np_arr.reshape(shape) + + +def _numpy_to_tensor(np_arr, device, orig_dtype=None): + dtype_str = "" + if orig_dtype is not None: + dtype_str = str(orig_dtype) + for prefix in ("infinicore.", "torch."): + if dtype_str.startswith(prefix): + dtype_str = dtype_str[len(prefix) :] + break + + shape = list(np_arr.shape) + cpu_dev = _get_cpu_device() + + if dtype_str == "bfloat16": + bf16_raw = _fp32_to_bf16(np_arr.astype(np.float32)) + cpu_result = torch.empty(shape, dtype=orig_dtype, device=cpu_dev) + ptr = cpu_result.data_ptr() + ctypes.memmove(ptr, bf16_raw.ctypes.data, bf16_raw.nbytes) + if _is_cpu_device(device): + return cpu_result + else: + result = torch.empty(shape, dtype=orig_dtype, device=device) + result.copy_(cpu_result) + return result + else: + cpu_tensor = torch.from_numpy(np_arr.copy()) + if _is_cpu_device(device): + return cpu_tensor + else: + return cpu_tensor.to(device) + + +def _is_cpu_device(device): + return "cpu" in str(device).lower() + + +def _unique(input, sorted=True, return_inverse=False, return_counts=False): + orig_device = input.device + np_input = _tensor_to_numpy(input).reshape(-1) + 
+ np_results = np.unique( + np_input, + return_inverse=return_inverse, + return_counts=return_counts, + ) + + if return_inverse and return_counts: + np_unique, np_inverse, np_counts = np_results + elif return_inverse: + np_unique, np_inverse = np_results + elif return_counts: + np_unique, np_counts = np_results + else: + np_unique = np_results + + if not sorted: + first_indices = np.array( + [np.nonzero(np_input == val)[0][0] for val in np_unique] + ) + order = np.argsort(first_indices) + np_unique = np_unique[order] + if return_inverse: + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + np_inverse = inv_order[np_inverse] + if return_counts: + np_counts = np_counts[order] + + unique_tensor = _numpy_to_tensor(np_unique, orig_device) + + result = (unique_tensor,) + if return_inverse: + result += (_numpy_to_tensor(np_inverse.astype(np.int64), orig_device),) + if return_counts: + result += (_numpy_to_tensor(np_counts.astype(np.int64), orig_device),) + + return result if len(result) > 1 else result[0] + + +def _gather_by_indices(temp_out, inverse_indices, out_flat): + """ + temp_out: [M, D] tensor + inverse_indices: [N] tensor (int64) + out_flat: [N, D] tensor + """ + np_temp = _tensor_to_numpy(temp_out) # shape: [M, D] + np_inv = _tensor_to_numpy(inverse_indices) # shape: [N] + np_inv = np_inv.astype(np.int64) + np_result = np_temp[np_inv] # shape: [N, D] + result_tensor = _numpy_to_tensor( + np_result, out_flat.device, orig_dtype=out_flat.dtype + ) + out_flat.copy_(result_tensor) + + +def embedding(input, weight, out=None, max_norm=None, norm_type=2.0): + if out is None: + out_shape = list(input.shape) + [weight.shape[1]] + out = torch.empty(out_shape, dtype=weight.dtype, device=input.device) + + # kernel = _cached_make(ntops.kernels.embedding.premake, input.dim(), weight.shape[0], weight.shape[1], block_size_m=4, block_size_n=4) + # kernel(input, weight, out, max_norm, norm_type) + + # Find unique indices to reduce redundant computations, 
then map back + # to original output positions. This is especially beneficial when input + # contains many repeated indices. Otherwise, data races in parallelism will cause errors. + unique_indices, inverse_indices = _unique( + input.view([input.numel()]), return_inverse=True + ) + temp_out = torch.empty( + [unique_indices.shape[0], weight.shape[1]], + dtype=weight.dtype, + device=input.device, + ) + if max_norm is None: + kernel = _cached_make( + ntops.kernels.embedding.premake_without_norm, + len(unique_indices.shape), + dtype=weight.dtype, + block_size_m=4, + block_size_n=4, + ) + kernel(unique_indices, weight, temp_out) + else: + kernel = _cached_make( + ntops.kernels.embedding.premake, + len(unique_indices.shape), + embedding_dim=weight.shape[1], + dtype=weight.dtype, + block_size_m=4, + block_size_n=4, + ) + kernel(unique_indices, weight, temp_out, max_norm, norm_type) + + # out_flat = out.view(-1, weight.shape[1]) + out_flat = out.view([input.numel(), weight.shape[1]]) + # out_flat[:] = temp_out[inverse_indices] + _gather_by_indices(temp_out, inverse_indices, out_flat) + return out diff --git a/src/ntops/torch/hardshrink.py b/src/ntops/torch/hardshrink.py new file mode 100644 index 0000000..0665386 --- /dev/null +++ b/src/ntops/torch/hardshrink.py @@ -0,0 +1,21 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + + +def hardshrink(input, lambd=0.5): + if isinstance(lambd, torch.Tensor): + lambd = float(lambd.item()) + else: + lambd = float(lambd) + + output = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.hardshrink.hardshrink_premake, + input.ndim, + block_size=1024, + ) + + kernel(input, lambd, output) + return output diff --git a/src/ntops/torch/hardsigmoid.py b/src/ntops/torch/hardsigmoid.py new file mode 100644 index 0000000..1539bf8 --- /dev/null +++ b/src/ntops/torch/hardsigmoid.py @@ -0,0 +1,21 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + + +def hardsigmoid(input, 
inplace=False): + kernel = _cached_make( + ntops.kernels.hardsigmoid.premake, + input.ndim, + inplace=inplace, + dtype=input.dtype, + block_size=1024, + ) + + if inplace: + kernel(input) + return input + else: + output = torch.empty_like(input) + kernel(input, output) + return output diff --git a/tests/test_argmin.py b/tests/test_argmin.py new file mode 100644 index 0000000..c96568a --- /dev/null +++ b/tests/test_argmin.py @@ -0,0 +1,40 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import _random_shape + + +def generate_argmin_arguments(): + """专门为 argmin 生成参数""" + arguments = [] + dtype_arr = (torch.float32, torch.float16) + + for ndim in range(1, 5): + for dtype in dtype_arr: + device = "cuda" + atol = 0.001 if dtype is torch.float32 else 0.01 + rtol = 0.001 if dtype is torch.float32 else 0.01 + + shape = _random_shape(ndim) + + # axis=None 的情况 + for keepdim in [False, True]: + arguments.append((shape, None, keepdim, dtype, device, rtol, atol)) + + # axis=0 到 ndim-1 的情况 + for axis in range(ndim): + for keepdim in [False, True]: + arguments.append((shape, axis, keepdim, dtype, device, rtol, atol)) + + return "shape, axis, keepdim, dtype, device, rtol, atol", arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_argmin_arguments()) +def test_argmin(shape, axis, keepdim, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + output = ntops.torch.argmin(input, axis=axis, keepdim=keepdim) + reference = torch.argmin(input, dim=axis, keepdim=keepdim) + torch.testing.assert_close(output, reference, rtol=rtol, atol=atol) diff --git a/tests/test_cosine_embedding_loss.py b/tests/test_cosine_embedding_loss.py new file mode 100644 index 0000000..94a195f --- /dev/null +++ b/tests/test_cosine_embedding_loss.py @@ -0,0 +1,67 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils 
import generate_arguments + + +def manual_cosine_embedding_loss(x1, x2, y, margin=0.0, reduction="mean"): + if x1.dim() == 1: + x1 = x1.unsqueeze(0) # (D,) -> (1, D) + x2 = x2.unsqueeze(0) + if y.dim() == 0: + y = y.unsqueeze(0) # () -> (1,) + + cosine = torch.nn.functional.cosine_similarity(x1, x2, dim=-1, eps=1e-8) + + loss = torch.where(y == 1, 1.0 - cosine, torch.clamp(cosine - margin, min=0.0)) + + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + else: # 'none' + return loss + + +def generate_arguments(): + return "shape,dtype,device,rtol,atol", [ + ( + ( + 15, + 411, + ), + torch.float16, + "cuda", + 1e-3, + 1e-2, + ), + ] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_cosine_embedding_loss(shape, dtype, device, rtol, atol): + if len(shape) > 2: + pytest.skip("Skipping test for tensors with more than 2 dimensions.") + else: + x1 = torch.randn(shape, dtype=dtype, device=device) + x2 = torch.randn(shape, dtype=dtype, device=device) + if len(shape) == 1: + y = torch.randint(-1, 2, (1,), device=device).float() + y[y == 0] = 1 + y = y.squeeze() + else: + y = torch.randint(-1, 2, shape[:-1], device=device).float() + y[y == 0] = 1 + margin = 0.5 + + manual_output = manual_cosine_embedding_loss( + x1.clone(), x2.clone(), y.clone(), margin=margin, reduction="mean" + ) + ninetoothed_output = ntops.torch.cosine_embedding_loss( + x1.clone(), x2.clone(), y.clone(), margin=margin, reduction="mean" + ) + # reference_output = torch.nn.functional.cosine_embedding_loss(x1, x2, y, margin=margin, reduction='mean') + assert torch.allclose(ninetoothed_output, manual_output, rtol=rtol, atol=atol) diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 0000000..083da48 --- /dev/null +++ b/tests/test_embedding.py @@ -0,0 +1,32 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import 
generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_embedding(shape, dtype, device, rtol, atol): + vocab_size = 10000 + embedding_dim = 2048 + + input = torch.randint(0, vocab_size, shape, device=device) + weight = torch.randn(vocab_size, embedding_dim, dtype=dtype, device=device) + + # among optional params, only max_norm and norm_type are supported, because only they affect output, others are for gradient calculation + ninetoothed_output = ntops.torch.embedding( + input, weight, max_norm=None, norm_type=1.1 + ) + reference_output = torch.nn.functional.embedding( + input, + weight, + padding_idx=None, + max_norm=None, + norm_type=1.1, + scale_grad_by_freq=False, + sparse=False, + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_hardshrink.py b/tests/test_hardshrink.py new file mode 100644 index 0000000..d30bd26 --- /dev/null +++ b/tests/test_hardshrink.py @@ -0,0 +1,24 @@ +import pytest +import torch +import random + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +# def generate_arguments(): +# return 'shape,dtype,device,rtol,atol', [ +# ((334,), torch.float16, 'cuda', 1e-2, 1e-2),] +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_hardshrink(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) * 10 + # lambd = torch.randn((), dtype=dtype, device=device).abs() + lambd = random.uniform(0.0, 2.0) + + output = ntops.torch.hardshrink(input, lambd) + + expected_output = torch.nn.functional.hardshrink(input, lambd) + + torch.testing.assert_close(output, expected_output, rtol=rtol, atol=atol) diff --git a/tests/test_hardsigmoid.py b/tests/test_hardsigmoid.py new file mode 100644 index 0000000..0563685 --- /dev/null +++ b/tests/test_hardsigmoid.py @@ -0,0 +1,31 @@ +import pytest +import torch +import random 
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+m = torch.nn.Hardsigmoid()
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+@pytest.mark.parametrize("inplace", [False, True], ids=["not_inplace", "inplace"])
+def test_hardsigmoid(shape, dtype, device, rtol, atol, inplace):
+    input = torch.randn(shape, dtype=dtype, device=device)
+    input_copy = input.clone()
+    input_id = id(input)
+
+    output = ntops.torch.hardsigmoid(input, inplace=inplace)
+
+    if inplace:
+        # Verify that inplace=True returns the very same tensor object
+        assert id(output) == input_id, "inplace=True should return the same tensor"
+    else:
+        # Verify that the non-inplace path leaves the input untouched
+        torch.testing.assert_close(input, input_copy, rtol=0, atol=0)
+
+    # Verify the result is numerically correct
+    reference = m(input_copy)
+    torch.testing.assert_close(output, reference, rtol=rtol, atol=atol)

From c3f7ab43e8db887f6e85890da8c3dd376db75a16 Mon Sep 17 00:00:00 2001
From: jie
Date: Fri, 13 Mar 2026 19:00:56 +0800
Subject: [PATCH 2/2] add matrix_exp

---
 src/ntops/torch/matrix_exp.py | 140 ++++++++++++++++++++++++++++++++++
 tests/test_matrix_exp.py      |  45 +++++++++++
 2 files changed, 185 insertions(+)
 create mode 100644 src/ntops/torch/matrix_exp.py
 create mode 100644 tests/test_matrix_exp.py

diff --git a/src/ntops/torch/matrix_exp.py b/src/ntops/torch/matrix_exp.py
new file mode 100644
index 0000000..2cced42
--- /dev/null
+++ b/src/ntops/torch/matrix_exp.py
@@ -0,0 +1,140 @@
+import torch
+import ntops
+from ntops.torch.utils import _cached_make, _get_matmul_input_precision
+
+
+# def matrix_exp(input):
+#     output = torch.empty_like(input)
+
+#     c = [
+#         1.0,
+#         12.0,
+#         66.0,
+#         220.0,
+#         495.0,
+#         792.0,
+#         924.0
+#     ]
+
+#     N_6 = torch.empty_like(input)
+#     temp_N_6 = torch.empty_like(input)
+#     D_6 = torch.empty_like(input)
+#     temp_D_6 = torch.empty_like(input)
+#     temp_up1 = torch.empty_like(input)
+#     temp_up2 = torch.empty_like(input)
+#     temp_down1 = torch.empty_like(input)
+#     temp_down2 = 
torch.empty_like(input)
+#     I = torch.eye(input.shape[-1], device=input.device, dtype=input.dtype).expand_as(input)
+
+
+#     mm_kernel = _cached_make(ntops.kernels.mm.premake)
+#     addmm_kernel = _cached_make(ntops.kernels.addmm.premake)
+
+#     for i in range(7):
+#         if i == 0:
+#             N_6.copy_(I)
+#             D_6.copy_(I)
+#             temp_up1.copy_(I)
+#             temp_down1.copy_(I)
+#         else:
+#             mm_kernel(input, N_6, temp_N_6, _get_matmul_input_precision())
+#             N_6 = temp_N_6
+#             addmm_kernel(temp_up1, I, N_6, 1.0, c[i], temp_up2, _get_matmul_input_precision())
+#             temp_up1 = temp_up2
+
+#             mm_kernel(input, D_6, temp_D_6, _get_matmul_input_precision())
+#             D_6 = temp_D_6
+#             addmm_kernel(temp_down1, I, D_6, 1.0, c[i] if i % 2 == 0 else -c[i], temp_down2, _get_matmul_input_precision())
+#             temp_down1 = temp_down2
+
+#     # div_kernel = _cached_make(ntops.kernels.div.premake, input.ndim, None)
+#     # div_kernel(temp_up1, temp_down1, output)
+#     torch.linalg.solve(temp_down1, temp_up1, out=output)
+#     return output
+
+def matrix_exp(A):
+    """
+    Compute the matrix exponential exp(A) via a 10th-order Taylor series with scaling-and-squaring
+
+    exp(A) = I + A + A²/2! + A³/3! + ... + A¹⁰/10!
+
+    Scaling: A_scaled = A / 2^s, where s = max(0, ceil(log2(||A||/theta)))
+    Squaring: exp(A) = (exp(A_scaled))^(2^s)
+    """
+    original_dtype = A.dtype
+    device = A.device
+
+    if A.dtype in (torch.float16, torch.bfloat16):
+        A = A.float()
+
+    if A.ndim == 2:
+        if A.shape[0] != A.shape[1]:
+            raise RuntimeError(f"matrix_exp requires square matrix, got shape {A.shape}")
+        batch_mode = False
+        n = A.shape[0]
+    elif A.ndim == 3:
+        if A.shape[1] != A.shape[2]:
+            raise RuntimeError(f"matrix_exp requires square matrices in batch, got shape {A.shape}")
+        batch_mode = True
+        batch_size = A.shape[0]
+        n = A.shape[1]
+    else:
+        raise RuntimeError(f"matrix_exp expects 2D or 3D input, got ndim={A.ndim}")
+
+    # ===== Scaling (reduce the norm) =====
+    if batch_mode:
+        # compute a norm per batch entry
+        norm_A = torch.linalg.norm(A, ord=1, dim=(1, 2)).max()  # take the largest norm in the batch
+    else:
+        norm_A = torch.linalg.norm(A, ord=1)
+
+    theta = 2.2  # threshold tuned for a 10th-order Taylor expansion
+    norm_val = norm_A.item()
+    if norm_val <= theta:
+        s_val = 0
+    else:
+        import math
+        s_val = max(0, math.ceil(math.log2(norm_val / theta)))
+
+    # scale the matrix
+    A_scaled = A / (2.0 ** s_val)
+
+    # ===== Initialization =====
+    dtype = A.dtype
+
+    if batch_mode:
+        I = torch.eye(n, dtype=dtype, device=device).unsqueeze(0).expand(batch_size, n, n)
+    else:
+        I = torch.eye(n, dtype=dtype, device=device)
+
+    # ===== 10th-order Taylor series =====
+    # exp(A_scaled) = sum_{k=0}^{10} A_scaled^k / k!
+
+    result = I.clone()  # start from I (the k=0 term)
+    power_k = I.clone()  # running power A_scaled^k
+    temp = I.clone()
+
+    # Taylor coefficients 1/k!
+    factorial = [1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0, 3628800.0]
+    for k in range(1, 11):
+        power_k = ntops.torch.matmul(power_k, A_scaled)
+
+        # result += power_k / k!
+
+        coeff = 1.0 / factorial[k]
+        # result = result + coeff * power_k
+        if A.ndim == 2:
+            result = ntops.torch.addmm(result, I, power_k, beta=1.0, alpha=coeff)
+        else:
+            # in batch mode, apply addmm batch by batch
+            for b in range(batch_size):
+                result[b] = ntops.torch.addmm(result[b], I[b], power_k[b], beta=1.0, alpha=coeff)
+
+    # ===== Squaring: exp(A) = (exp(A_scaled))^(2^s) =====
+    for _ in range(s_val):
+        temp = result
+        result = ntops.torch.matmul(result, temp)
+
+    # ===== Cast back to the original dtype =====
+    result = result.to(original_dtype)
+
+    return result
diff --git a/tests/test_matrix_exp.py b/tests/test_matrix_exp.py
new file mode 100644
index 0000000..59a2766
--- /dev/null
+++ b/tests/test_matrix_exp.py
@@ -0,0 +1,45 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+def generate_arguments():
+    arguments = []
+
+    for dtype in (torch.float32, torch.bfloat16):
+        device = "cuda"
+
+        if dtype is torch.float32:
+            atol = 0.01
+            rtol = 0.01
+        else:
+            atol = 0.05
+            rtol = 0.05
+
+        def generate_random_size():
+            return random.randint(16, 64)
+
+        m = generate_random_size()
+        n = generate_random_size()
+        k = generate_random_size()
+
+        arguments.append((m, n, k, dtype, device, rtol, atol))
+
+    return "m, n, k, dtype, device, rtol, atol", arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_mm(m, n, k, dtype, device, rtol, atol):
+    input = torch.randn((m, n, n), dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.matrix_exp(input)
+    reference_output = torch.linalg.matrix_exp(input)
+
+    print(reference_output)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)