Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -29,6 +29,7 @@
runtime/WebGPUBackend.cpp runtime/WebGPUGraph.cpp
runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp
runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
2 changes: 1 addition & 1 deletion backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Result<DelegateHandle*> WebGPUBackend::init(
}

try {
graph->build(flatbuffer_data, constant_data);
graph->build(flatbuffer_data, constant_data, context.get_named_data_map());
} catch (const std::exception& e) {
ET_LOG(Error, "WebGPU graph build failed: %s", e.what());
graph->~WebGPUGraph();
Expand Down
15 changes: 14 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/serialization/schema_generated.h>
#include <executorch/runtime/core/named_data_map.h>

#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
#include <webgpu/wgpu.h>
Expand Down Expand Up @@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() {

void WebGPUGraph::build(
const void* flatbuffer_data,
const uint8_t* constant_data) {
const uint8_t* constant_data,
const executorch::runtime::NamedDataMap* named_data_map) {
if (!device_) {
auto* ctx = get_default_webgpu_context();
if (ctx) {
Expand Down Expand Up @@ -165,6 +167,17 @@ void WebGPUGraph::build(
const uint8_t* src = constant_data + vk_bytes->offset();
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, src, tensor.nbytes);
} else if (
vk_bytes->named_key() != nullptr &&
named_data_map != nullptr) {
// Constant stored in the PTE named-data map.
auto buf =
named_data_map->get_data(vk_bytes->named_key()->c_str());
if (buf.ok() && buf->size() >= tensor.nbytes) {
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, buf->data(), tensor.nbytes);
buf->Free();
}
}
}
}
Expand Down
11 changes: 10 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
#include <unordered_map>
#include <vector>

namespace executorch {
namespace runtime {
class NamedDataMap;
} // namespace runtime
} // namespace executorch

namespace executorch {
namespace backends {
namespace webgpu {
Expand Down Expand Up @@ -66,7 +72,10 @@ class WebGPUGraph {

// Build the graph from a deserialized VkGraph flatbuffer and constant data.
// The flatbuffer_data pointer must remain valid during build().
void build(const void* flatbuffer_data, const uint8_t* constant_data);
void build(
const void* flatbuffer_data,
const uint8_t* constant_data,
const executorch::runtime::NamedDataMap* named_data_map = nullptr);

// Copy input tensor data from host pointers into GPU buffers.
void copy_inputs(const std::vector<std::pair<const void*, size_t>>& inputs);
Expand Down
192 changes: 192 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>

namespace executorch {
namespace backends {
namespace webgpu {

namespace {

// Uniform buffer layout matching the WGSL Params struct.
// Must be 16-byte aligned for WebGPU uniform buffer requirements.
struct RmsNormParams {
uint32_t num_rows;
uint32_t row_width;
float epsilon;
uint32_t _pad;
};
static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes");

void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// et_vk.rms_norm.default args: [in, weight, eps, out]
const int in_id = args.at(0);
const int weight_id = args.at(1);
const int eps_id = args.at(2);
const int out_id = args.at(3);

WGPUDevice device = graph.device();

// Get epsilon (Double from a Python float; defaults to 0.0)
float epsilon = 0.0f;
if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) {
epsilon = static_cast<float>(graph.get_double(eps_id));
} else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) {
epsilon = static_cast<float>(graph.get_int(eps_id));
}

// row_width = last dim; num_rows = product of the rest (PyTorch NCHW order)
const auto& in_tensor = graph.get_tensor(in_id);
if (in_tensor.dims.empty() || in_tensor.nbytes == 0) {
return;
}
const uint32_t row_width = static_cast<uint32_t>(in_tensor.dims.back());
if (row_width == 0) {
return;
}
uint64_t in_numel = 1;
for (int64_t d : in_tensor.dims) {
in_numel *= static_cast<uint64_t>(d);
}
// fp32-only shader: bail if the bytes don't match an fp32 element count.
if (in_tensor.nbytes != in_numel * sizeof(float)) {
return;
}
const uint32_t num_rows = static_cast<uint32_t>(in_numel / row_width);
if (num_rows == 0) {
return;
}

// Create uniform buffer for params
RmsNormParams params = {};
params.num_rows = num_rows;
params.row_width = row_width;
params.epsilon = epsilon;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(RmsNormParams);
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
uniform_desc.mappedAtCreation = true;
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
void* mapped =
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams));
std::memcpy(mapped, &params, sizeof(RmsNormParams));
wgpuBufferUnmap(uniform_buffer);

graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));

// Create shader module from built-in WGSL source
WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};

WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// Create bind group layout: out (rw) + in/weight (ro storage) + params
WGPUBindGroupLayoutEntry entries[4] = {};

// t_out - storage buffer, read-write
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_Storage;

// t_in - storage buffer, read-only
entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

// t_weight - storage buffer, read-only
entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

// params - uniform buffer
entries[3].binding = 3;
entries[3].visibility = WGPUShaderStage_Compute;
entries[3].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 4;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

// Create pipeline layout
WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

// Create compute pipeline
WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

// Create bind group with actual buffers
const auto& out_tensor = graph.get_tensor(out_id);
const auto& weight_tensor = graph.get_tensor(weight_id);

WGPUBindGroupEntry bg_entries[4] = {};

bg_entries[0].binding = 0;
bg_entries[0].buffer = out_tensor.buffer;
bg_entries[0].size = out_tensor.nbytes;

bg_entries[1].binding = 1;
bg_entries[1].buffer = in_tensor.buffer;
bg_entries[1].size = in_tensor.nbytes;

bg_entries[2].binding = 2;
bg_entries[2].buffer = weight_tensor.buffer;
bg_entries[2].size = weight_tensor.nbytes;

bg_entries[3].binding = 3;
bg_entries[3].buffer = uniform_buffer;
bg_entries[3].size = sizeof(RmsNormParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 4;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

// One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row)
static_assert(
kRmsNormWorkgroupSize == 64,
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
graph.add_dispatch({pipeline, bind_group, num_rows});

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// uniform_buffer is kept alive by the bind group
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl);
}

} // namespace webgpu
} // namespace backends
} // namespace executorch
72 changes: 72 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;

struct Params {
num_rows: u32,
row_width: u32,
epsilon: f32,
_pad: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

const WG_SIZE: u32 = 64u;

var<workgroup> shared_sum: array<f32, WG_SIZE>;

fn reduce_shared(worker_id: u32) {
workgroupBarrier();
var stride: u32 = WG_SIZE / 2u;
loop {
if (stride == 0u) {
break;
}
if (worker_id < stride) {
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
}
workgroupBarrier();
stride = stride >> 1u;
}
}

@compute @workgroup_size(64, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = wid.x;
let worker_id = lid.x;

if (row_idx >= params.num_rows) {
return;
}

let base = row_idx * params.row_width;

var local_sq_sum: f32 = 0.0;
var x: u32 = worker_id;
loop {
if (x >= params.row_width) {
break;
}
let v = t_in[base + x];
local_sq_sum = local_sq_sum + v * v;
x = x + WG_SIZE;
}

shared_sum[worker_id] = local_sq_sum;
reduce_shared(worker_id);

let mean_sq = shared_sum[0] / f32(params.row_width);
let rstd = inverseSqrt(mean_sq + params.epsilon);

x = worker_id;
loop {
if (x >= params.row_width) {
break;
}
let v = t_in[base + x];
let w = t_weight[x];
t_out[base + x] = v * rstd * w;
x = x + WG_SIZE;
}
}
Loading
Loading