diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index ab2da24a569..67b1655d94c 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -29,6 +29,7 @@ set(WEBGPU_SRCS runtime/WebGPUBackend.cpp runtime/WebGPUGraph.cpp runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp + runtime/ops/rms_norm/RmsNorm.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index 5321c20aaa4..b4e3165d8f4 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -76,7 +76,7 @@ Result WebGPUBackend::init( } try { - graph->build(flatbuffer_data, constant_data); + graph->build(flatbuffer_data, constant_data, context.get_named_data_map()); } catch (const std::exception& e) { ET_LOG(Error, "WebGPU graph build failed: %s", e.what()); graph->~WebGPUGraph(); diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 91404fb164f..1ea67ae1109 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() { void WebGPUGraph::build( const void* flatbuffer_data, - const uint8_t* constant_data) { + const uint8_t* constant_data, + const executorch::runtime::NamedDataMap* named_data_map) { if (!device_) { auto* ctx = get_default_webgpu_context(); if (ctx) { @@ -165,6 +167,17 @@ void WebGPUGraph::build( const uint8_t* src = constant_data + vk_bytes->offset(); wgpuQueueWriteBuffer( queue_, tensor.buffer, 0, src, tensor.nbytes); + } else if ( + vk_bytes->named_key() != nullptr && + named_data_map != nullptr) { + // Constant stored in the PTE named-data map. 
+ auto buf = + named_data_map->get_data(vk_bytes->named_key()->c_str()); + if (buf.ok() && buf->size() >= tensor.nbytes) { + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, buf->data(), tensor.nbytes); + buf->Free(); + } } } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 3aa96917a4e..fa171906b67 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -15,6 +15,12 @@ #include #include +namespace executorch { +namespace runtime { +class NamedDataMap; +} // namespace runtime +} // namespace executorch + namespace executorch { namespace backends { namespace webgpu { @@ -66,7 +72,10 @@ class WebGPUGraph { // Build the graph from a deserialized VkGraph flatbuffer and constant data. // The flatbuffer_data pointer must remain valid during build(). - void build(const void* flatbuffer_data, const uint8_t* constant_data); + void build( + const void* flatbuffer_data, + const uint8_t* constant_data, + const executorch::runtime::NamedDataMap* named_data_map = nullptr); // Copy input tensor data from host pointers into GPU buffers. void copy_inputs(const std::vector>& inputs); diff --git a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp new file mode 100644 index 00000000000..47067fec944 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace webgpu { + +namespace { + +// Uniform buffer layout matching the WGSL Params struct. +// Must be 16-byte aligned for WebGPU uniform buffer requirements. 
+struct RmsNormParams {
+  uint32_t num_rows; // rows to normalize; one workgroup is dispatched per row
+  uint32_t row_width; // elements per row (size of the last dimension)
+  float epsilon; // added to mean(x^2) before the reciprocal square root
+  uint32_t _pad; // explicit padding up to 16 bytes
+};
+static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes");
+
+// Records the compute dispatch for et_vk.rms_norm.default on `graph`.
+//
+// args holds the value ids [in, weight, eps, out]; args.at() throws
+// std::out_of_range when fewer than four ids are supplied. eps may be a Double
+// or an Int value; any other value type leaves epsilon at 0.0.
+//
+// No dispatch is recorded when the input has no dims, zero bytes, a zero-sized
+// last dim, or a byte size that is not numel * sizeof(float) (the shader is
+// fp32-only).
+void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
+  // et_vk.rms_norm.default args: [in, weight, eps, out]
+  const int in_id = args.at(0);
+  const int weight_id = args.at(1);
+  const int eps_id = args.at(2);
+  const int out_id = args.at(3);
+
+  WGPUDevice device = graph.device();
+
+  // Get epsilon (Double from a Python float; defaults to 0.0)
+  float epsilon = 0.0f;
+  if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) {
+    epsilon = static_cast<float>(graph.get_double(eps_id));
+  } else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) {
+    epsilon = static_cast<float>(graph.get_int(eps_id));
+  }
+
+  // row_width = last dim; num_rows = product of the rest (PyTorch NCHW order)
+  const auto& in_tensor = graph.get_tensor(in_id);
+  if (in_tensor.dims.empty() || in_tensor.nbytes == 0) {
+    return;
+  }
+  const uint32_t row_width = static_cast<uint32_t>(in_tensor.dims.back());
+  if (row_width == 0) {
+    return;
+  }
+  uint64_t in_numel = 1;
+  for (int64_t d : in_tensor.dims) {
+    in_numel *= static_cast<uint64_t>(d);
+  }
+  // fp32-only shader: bail if the bytes don't match an fp32 element count.
+  if (in_tensor.nbytes != in_numel * sizeof(float)) {
+    return;
+  }
+  const uint32_t num_rows = static_cast<uint32_t>(in_numel / row_width);
+  if (num_rows == 0) {
+    return;
+  }
+
+  // Create uniform buffer for params
+  RmsNormParams params = {};
+  params.num_rows = num_rows;
+  params.row_width = row_width;
+  params.epsilon = epsilon;
+
+  WGPUBufferDescriptor uniform_desc = {};
+  uniform_desc.size = sizeof(RmsNormParams);
+  uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+  uniform_desc.mappedAtCreation = true;
+  WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
+  void* mapped =
+      wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams));
+  std::memcpy(mapped, &params, sizeof(RmsNormParams));
+  wgpuBufferUnmap(uniform_buffer);
+
+  graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));
+
+  // Create shader module from built-in WGSL source
+  WGPUShaderSourceWGSL wgsl_desc = {};
+  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
+  wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};
+
+  WGPUShaderModuleDescriptor shader_desc = {};
+  shader_desc.nextInChain = &wgsl_desc.chain;
+  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
+
+  // Create bind group layout: out (rw) + in/weight (ro storage) + params
+  WGPUBindGroupLayoutEntry entries[4] = {};
+
+  // t_out - storage buffer, read-write
+  entries[0].binding = 0;
+  entries[0].visibility = WGPUShaderStage_Compute;
+  entries[0].buffer.type = WGPUBufferBindingType_Storage;
+
+  // t_in - storage buffer, read-only
+  entries[1].binding = 1;
+  entries[1].visibility = WGPUShaderStage_Compute;
+  entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+
+  // t_weight - storage buffer, read-only
+  entries[2].binding = 2;
+  entries[2].visibility = WGPUShaderStage_Compute;
+  entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+
+  // params - uniform buffer
+  entries[3].binding = 3;
+  entries[3].visibility = WGPUShaderStage_Compute;
+  entries[3].buffer.type = WGPUBufferBindingType_Uniform;
+
+  WGPUBindGroupLayoutDescriptor bgl_desc = {};
+  bgl_desc.entryCount = 4;
+  bgl_desc.entries = entries;
+  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
+
+  // Create pipeline layout
+  WGPUPipelineLayoutDescriptor pl_desc = {};
+  pl_desc.bindGroupLayoutCount = 1;
+  pl_desc.bindGroupLayouts = &bgl;
+  WGPUPipelineLayout pipeline_layout =
+      wgpuDeviceCreatePipelineLayout(device, &pl_desc);
+
+  // Create compute pipeline
+  WGPUComputePipelineDescriptor pipeline_desc = {};
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader;
+  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
+  WGPUComputePipeline pipeline =
+      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
+
+  // Create bind group with actual buffers
+  const auto& out_tensor = graph.get_tensor(out_id);
+  const auto& weight_tensor = graph.get_tensor(weight_id);
+
+  WGPUBindGroupEntry bg_entries[4] = {};
+
+  bg_entries[0].binding = 0;
+  bg_entries[0].buffer = out_tensor.buffer;
+  bg_entries[0].size = out_tensor.nbytes;
+
+  bg_entries[1].binding = 1;
+  bg_entries[1].buffer = in_tensor.buffer;
+  bg_entries[1].size = in_tensor.nbytes;
+
+  bg_entries[2].binding = 2;
+  bg_entries[2].buffer = weight_tensor.buffer;
+  bg_entries[2].size = weight_tensor.nbytes;
+
+  bg_entries[3].binding = 3;
+  bg_entries[3].buffer = uniform_buffer;
+  bg_entries[3].size = sizeof(RmsNormParams);
+
+  WGPUBindGroupDescriptor bg_desc = {};
+  bg_desc.layout = bgl;
+  bg_desc.entryCount = 4;
+  bg_desc.entries = bg_entries;
+  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
+
+  // One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row)
+  static_assert(
+      kRmsNormWorkgroupSize == 64,
+      "must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
+  graph.add_dispatch({pipeline, bind_group, num_rows});
+
+  // Release intermediate objects (pipeline and bind_group are kept by dispatch)
+  wgpuShaderModuleRelease(shader);
+  wgpuBindGroupLayoutRelease(bgl);
+  wgpuPipelineLayoutRelease(pipeline_layout);
+  // The bind group holds its own reference to uniform_buffer, so drop ours;
+  // keeping it would leak the buffer once the bind group is released.
+  wgpuBufferRelease(uniform_buffer);
+}
+
+} // namespace
+
+WEBGPU_REGISTER_OPERATORS {
+  WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl);
+}
+
+} // namespace webgpu
+} // namespace backends
+} // namespace executorch
diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl b/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl
new file mode 100644
index 00000000000..4bd5618596f
--- /dev/null
+++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl
@@ -0,0 +1,72 @@
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_in: array<f32>;
+@group(0) @binding(2) var<storage, read> t_weight: array<f32>;
+
+struct Params {
+  num_rows: u32,
+  row_width: u32,
+  epsilon: f32,
+  _pad: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+const WG_SIZE: u32 = 64u;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+fn reduce_shared(worker_id: u32) {
+  workgroupBarrier();
+  var stride: u32 = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+}
+
+@compute @workgroup_size(64, 1, 1)
+fn main(
+    @builtin(workgroup_id) wid: vec3<u32>,
+    @builtin(local_invocation_id) lid: vec3<u32>) {
+  let row_idx = wid.x;
+  let worker_id = lid.x;
+
+  if (row_idx >= params.num_rows) {
+    return;
+  }
+
+  let base = row_idx * params.row_width;
+
+  var local_sq_sum: f32 = 0.0;
+  var x: u32 = worker_id;
+  loop {
+    if (x >= params.row_width) {
+      break;
+    }
+    let v = t_in[base + x];
+    local_sq_sum = local_sq_sum + v * v;
+    x = x + WG_SIZE;
+  }
+
+  shared_sum[worker_id] = local_sq_sum;
+  reduce_shared(worker_id);
+
+  let mean_sq = shared_sum[0] / f32(params.row_width);
+  let rstd = inverseSqrt(mean_sq + params.epsilon);
+
+  x = worker_id;
+  loop {
+    if (x >= params.row_width) {
+      break;
+    }
+    let v = t_in[base + x];
+    let w =
t_weight[x];
+    t_out[base + x] = v * rstd * w;
+    x = x + WG_SIZE;
+  }
+}
diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h b/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h
new file mode 100644
index 00000000000..982d56b84db
--- /dev/null
+++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace webgpu {
+
+// WGSL shader source for rms_norm: y = x * w * rsqrt(mean(x^2) + eps)
+// One workgroup handles one row; its WG_SIZE invocations stride over the row
+// and combine their partial sums of squares through shared_sum.
+inline constexpr const char* kRmsNormWGSL = R"(
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_in: array<f32>;
+@group(0) @binding(2) var<storage, read> t_weight: array<f32>;
+
+struct Params {
+  num_rows: u32,
+  row_width: u32,
+  epsilon: f32,
+  _pad: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+const WG_SIZE: u32 = 64u;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+fn reduce_shared(worker_id: u32) {
+  workgroupBarrier();
+  var stride: u32 = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+}
+
+@compute @workgroup_size(64, 1, 1)
+fn main(
+    @builtin(workgroup_id) wid: vec3<u32>,
+    @builtin(local_invocation_id) lid: vec3<u32>) {
+  let row_idx = wid.x;
+  let worker_id = lid.x;
+
+  if (row_idx >= params.num_rows) {
+    return;
+  }
+
+  let base = row_idx * params.row_width;
+
+  var local_sq_sum: f32 = 0.0;
+  var x: u32 = worker_id;
+  loop {
+    if (x >= params.row_width) {
+      break;
+    }
+    let v = t_in[base + x];
+    local_sq_sum = local_sq_sum + v * v;
+    x = x + WG_SIZE;
+  }
+
+  shared_sum[worker_id] = local_sq_sum;
+  reduce_shared(worker_id);
+
+  let mean_sq = shared_sum[0] / f32(params.row_width);
+  let rstd = inverseSqrt(mean_sq + params.epsilon);
+
+  x = worker_id;
+  loop {
+    if (x >= params.row_width) {
+      break;
+    }
+    let v = t_in[base + x];
+    let w = t_weight[x];
+    t_out[base + x] = v * rstd * w;
+    x = x + WG_SIZE;
+  }
+}
+)";
+
+// Invocations per workgroup; must equal WG_SIZE and @workgroup_size above.
+inline constexpr uint32_t kRmsNormWorkgroupSize = 64;
+
+} // namespace webgpu
+} // namespace backends
+} // namespace executorch
diff --git a/backends/webgpu/test/ops/rms_norm/__init__.py b/backends/webgpu/test/ops/rms_norm/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py
new file mode 100644
index 00000000000..6b07724de57
--- /dev/null
+++ b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""fp32 RMSNorm export tests via VulkanPartitioner.
+
+Verifies the export side only; numerics are checked in the native test
+`test/test_webgpu_native.cpp`.
+"""
+
+import unittest
+
+import torch
+from executorch.backends.vulkan import VulkanPartitioner
+from executorch.exir import to_edge_transform_and_lower
+
+
+class RmsNormModule(torch.nn.Module):
+    """RMSNorm over the last dim, scaled by a learnable per-feature weight."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Normalize in fp32, then cast back to the input dtype.
+        xf = x.to(torch.float32)
+        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+        return (xf * inv_rms * self.weight).to(x.dtype)
+
+
+class TestRmsNorm(unittest.TestCase):
+    def _export_and_check(self, model, example_inputs) -> None:
+        """Lower `model` with VulkanPartitioner and check a delegate exists."""
+        exported = torch.export.export(model, example_inputs)
+        lowered = to_edge_transform_and_lower(
+            exported, partitioner=[VulkanPartitioner()]
+        )
+        et_program = lowered.to_executorch()
+        has_vulkan = any(
+            delegate.id == "VulkanBackend"
+            for plan in et_program.executorch_program.execution_plan
+            for delegate in plan.delegates
+        )
+        self.assertTrue(has_vulkan, "Expected VulkanBackend delegate in .pte")
+        self.assertGreater(len(et_program.buffer), 100)
+
+    def test_rms_norm_basic_small(self) -> None:
+        self._export_and_check(RmsNormModule(64), (torch.randn(1, 1, 1, 64),))
+
+    def test_rms_norm_llm_hidden(self) -> None:
+        # Hidden size typical of an LLM.
+        self._export_and_check(RmsNormModule(896), (torch.randn(1, 1, 1, 896),))
+
+    def test_rms_norm_multi_row(self) -> None:
+        # Several rows along the seq-len dimension (prefill-style).
+        self._export_and_check(RmsNormModule(896), (torch.randn(1, 1, 7, 896),))
+
+    def test_rms_norm_4d(self) -> None:
+        # 4D input with multiple Z slices, similar to a QK norm.
+        self._export_and_check(RmsNormModule(128), (torch.randn(1, 5, 4, 128),))
+
+
+def export_rms_norm_model(output_path: str) -> None:
+    """Write the fixed-seed RMSNorm .pte used by the native runtime test."""
+    torch.manual_seed(0)
+    hidden = 896
+    seq_len = 7
+    model = RmsNormModule(hidden, eps=1e-6)
+    # Known weight ramp; the native test reconstructs the same values.
+    ramp = torch.linspace(0.5, 1.5, hidden, dtype=torch.float32)
+    with torch.no_grad():
+        model.weight.copy_(ramp)
+    example_inputs = (torch.randn(1, 1, seq_len, hidden),)
+    exported = torch.export.export(model, example_inputs)
+    lowered = to_edge_transform_and_lower(
+        exported, partitioner=[VulkanPartitioner()]
+    )
+    et_program = lowered.to_executorch()
+    with open(output_path, "wb") as f:
+        f.write(et_program.buffer)
+    print(f"Exported {output_path}")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/backends/webgpu/test/test_build_webgpu.sh b/backends/webgpu/test/test_build_webgpu.sh
index a42b2304ee7..6c90b275924 100755
--- a/backends/webgpu/test/test_build_webgpu.sh
+++ b/backends/webgpu/test/test_build_webgpu.sh
@@ -17,20 +17,32 @@ NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu)
 
 # ── Step 1: Python export tests ──────────────────────────────────────────────
 
-echo "=== Step 1: Run Python export test ==="
+echo "=== Step 1: Run Python export tests ==="
 $PYTHON_EXECUTABLE -m pytest "${SCRIPT_DIR}/ops/add/test_add.py" -v
+# Non-fatal: a rms_norm pytest failure skips the rms_norm native test below
+# rather than aborting the whole run.
+RMS_NORM_PYTEST_OK=1
+$PYTHON_EXECUTABLE -m pytest "${SCRIPT_DIR}/ops/rms_norm/test_rms_norm.py" -v \
+  || RMS_NORM_PYTEST_OK=0
 
 # ── Step 2: Export .pte model ─────────────────────────────────────────────────
 
 echo "=== Step 2: Export test models ==="
 PTE_MODEL="/tmp/webgpu_add_test.pte"
 PTE_CHAINED_MODEL="/tmp/webgpu_chained_add_test.pte"
+PTE_RMS_NORM_MODEL="/tmp/webgpu_rms_norm_test.pte"
 cd "${EXECUTORCH_ROOT}"
 $PYTHON_EXECUTABLE -c "
 from executorch.backends.webgpu.test.ops.add.test_add import export_add_model, export_chained_add_model
 export_add_model('${PTE_MODEL}')
 export_chained_add_model('${PTE_CHAINED_MODEL}')
 "
+if [[ "${RMS_NORM_PYTEST_OK}" == "1" ]]; then
+  $PYTHON_EXECUTABLE -c "
+from executorch.backends.webgpu.test.ops.rms_norm.test_rms_norm import export_rms_norm_model
+export_rms_norm_model('${PTE_RMS_NORM_MODEL}')
+" || { echo "WARN: rms_norm export failed; skipping rms_norm native test"; RMS_NORM_PYTEST_OK=0; }
+fi
 
 # ── Step 3: Native build + test (wgpu-native) ────────────────────────────────
 
@@ -61,8 +73,15 @@ cmake \
 cmake --build "${NATIVE_BUILD_DIR}" --target webgpu_native_test -j${NPROC}
 
 echo "=== Step 4: Run native test ==="
+if [[ "${RMS_NORM_PYTEST_OK}" == "1" && -f "${PTE_RMS_NORM_MODEL}" ]]; then
+  export WEBGPU_TEST_RMS_NORM_MODEL="${PTE_RMS_NORM_MODEL}"
+else
+  # Also drop any inherited value so the native test really skips rms_norm.
+  unset WEBGPU_TEST_RMS_NORM_MODEL
+  echo "(skipping rms_norm native test: pytest or export did not complete)"
+fi
 WEBGPU_TEST_MODEL="${PTE_MODEL}" \
 WEBGPU_TEST_CHAINED_MODEL="${PTE_CHAINED_MODEL}" \
 "${NATIVE_BUILD_DIR}/backends/webgpu/webgpu_native_test"
 
 echo "=== Done ==="
diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp
index d3005debf37..097299eacae 100644
--- a/backends/webgpu/test/test_webgpu_native.cpp
+++ b/backends/webgpu/test/test_webgpu_native.cpp
@@ -10,10 +10,12 @@
 #include
 #include
 
+#include
#include
 #include
 #include
 #include
+#include
 
 using namespace executorch::backends::webgpu;
 using namespace executorch::extension;
@@ -131,6 +133,102 @@ static bool test_chained_add(const std::string& model_path) {
   return true;
 }
 
+// Runs the exported rms_norm model and compares every output element against
+// a CPU reference. Returns true on success; otherwise prints the reason (load
+// or forward failure, malformed output, non-finite output value, or an error
+// above the 1e-3 tolerance) and returns false.
+static bool test_rms_norm(const std::string& model_path) {
+  // rms_norm over (1,1,7,896), eps=1e-6; weight=linspace(0.5,1.5,896). CPU
+  // reference below.
+  printf("\n--- Test: rms_norm (1x1x7x896, eps=1e-6) ---\n");
+
+  Module module(model_path);
+  auto err = module.load_forward();
+  if (err != Error::Ok) {
+    printf("FAIL: could not load forward method (error %d)\n", (int)err);
+    return false;
+  }
+  printf("Model loaded: %s\n", model_path.c_str());
+
+  constexpr int hidden = 896;
+  constexpr int seq_len = 7;
+  constexpr int num_rows = seq_len; // batch=1, channels=1
+  constexpr int numel = num_rows * hidden;
+  constexpr float eps = 1e-6f;
+
+  // Deterministic input: linear ramp scaled to [-1, 1].
+  std::vector<float> x_data(numel);
+  for (int i = 0; i < numel; i++) {
+    x_data[i] = 2.0f * (static_cast<float>(i) / static_cast<float>(numel - 1)) - 1.0f;
+  }
+
+  // Reconstruct weight = torch.linspace(0.5, 1.5, hidden).
+  std::vector<float> w_data(hidden);
+  for (int i = 0; i < hidden; i++) {
+    w_data[i] = 0.5f + (static_cast<float>(i) / static_cast<float>(hidden - 1));
+  }
+
+  // CPU reference: per-row rsqrt(mean(x^2) + eps) scaled by w.
+  std::vector<float> ref_data(numel);
+  for (int r = 0; r < num_rows; r++) {
+    const int off = r * hidden;
+    float sq_sum = 0.0f;
+    for (int i = 0; i < hidden; i++) {
+      const float v = x_data[off + i];
+      sq_sum += v * v;
+    }
+    const float rstd = 1.0f / std::sqrt(sq_sum / static_cast<float>(hidden) + eps);
+    for (int i = 0; i < hidden; i++) {
+      ref_data[off + i] = x_data[off + i] * rstd * w_data[i];
+    }
+  }
+
+  auto x = make_tensor_ptr({1, 1, seq_len, hidden}, std::vector<float>(x_data));
+  auto result = module.forward({EValue(x)});
+  if (!result.ok()) {
+    printf("FAIL: forward failed (error %d)\n", (int)result.error());
+    return false;
+  }
+
+  const auto& outputs = result.get();
+  if (outputs.empty() || !outputs[0].isTensor()) {
+    printf("FAIL: no tensor output\n");
+    return false;
+  }
+
+  const auto& out_tensor = outputs[0].toTensor();
+  if (out_tensor.numel() != numel) {
+    printf("FAIL: output numel %zu != expected %d\n",
+           (size_t)out_tensor.numel(), numel);
+    return false;
+  }
+  const float* out_data = out_tensor.const_data_ptr<float>();
+
+  float max_abs_err = 0.0f;
+  float max_rel_err = 0.0f;
+  for (int i = 0; i < numel; i++) {
+    // std::max drops NaN operands, so a NaN output would otherwise leave both
+    // maxima untouched and slip past the tolerance check below.
+    if (!std::isfinite(out_data[i])) {
+      printf("FAIL: non-finite output at index %d\n", i);
+      return false;
+    }
+    const float abs_err = std::abs(out_data[i] - ref_data[i]);
+    max_abs_err = std::max(max_abs_err, abs_err);
+    const float denom = std::max(std::abs(ref_data[i]), 1e-6f);
+    max_rel_err = std::max(max_rel_err, abs_err / denom);
+  }
+
+  printf("Max abs error: %e  Max rel error: %e  (checked %d elements)\n",
+         max_abs_err, max_rel_err, numel);
+  if (max_abs_err > 1e-3f || max_rel_err > 1e-3f) {
+    printf("FAIL: error exceeds tolerance 1e-3\n");
+    return false;
+  }
+  printf("PASS: rms_norm test\n");
+  return true;
+}
+
 int main(int argc, char** argv) {
   std::string model_path = "webgpu_add_test.pte";
   if (argc > 1) {
@@ -145,6 +243,11 @@ int main(int argc, char** argv) {
     chained_model_path = env;
   }
 
+  std::string rms_norm_model_path;
+  if (const char* env = std::getenv("WEBGPU_TEST_RMS_NORM_MODEL")) {
+    rms_norm_model_path = env;
+  }
+
   WebGPUContext ctx;
   try {
     ctx = create_webgpu_context();
@@ -162,6 +265,10 @@ int main(int argc, char** argv) {
     ok = test_chained_add(chained_model_path) && ok;
   }
 
+  if (!rms_norm_model_path.empty()) {
+    ok = test_rms_norm(rms_norm_model_path) && ok;
+  }
+
   set_default_webgpu_context(nullptr);
   destroy_webgpu_context(ctx);