Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,42 @@ def linear_dq8ca_q4gsw(
lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)


#######################
## linear_dq8ca_q8csw ##
#######################


def linear_dq8ca_q8csw(
    x: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Reference implementation of the ``linear_dq8ca_q8csw`` custom op.

    The weight is dequantized with one scale per output channel and a regular
    floating-point linear is applied. ``input_scale``, ``input_zero_point`` and
    ``weight_sums`` are part of the op signature but are not read here.

    Args:
        x: Activation tensor; its dtype is used for the dequantized weight.
        input_scale: Unused by this reference implementation.
        input_zero_point: Unused by this reference implementation.
        weights: Quantized weight. Assumed to be laid out as
            ``(out_features, in_features)``, as required by
            ``torch.nn.functional.linear`` -- TODO confirm against callers.
        weight_sums: Unused by this reference implementation.
        weight_scales: Per-output-channel scales, either 1D ``[N]`` or a 2D
            ``[N, 1]`` column.
        bias: Optional bias forwarded to ``torch.nn.functional.linear``.

    Returns:
        The result of ``torch.nn.functional.linear(x, dequantized_weight, bias)``.
    """
    # Per-channel symmetric INT8 weight: dequant = weight.to(fp) * scales
    # (per output channel). Scales that already arrive as an [N, 1] column
    # must not be unsqueezed again, otherwise they become [N, 1, 1] and no
    # longer broadcast row-wise against the weight.
    if weight_scales.dim() == 2 and weight_scales.shape[-1] == 1:
        scales_column = weight_scales
    else:
        scales_column = weight_scales.unsqueeze(-1)

    weights_dq = weights.to(x.dtype) * scales_column
    return torch.nn.functional.linear(x, weights_dq, bias)


# Register the `linear_dq8ca_q8csw` op: declare its schema on `lib`, attach the
# Python function of the same name as its implementation, and keep a handle to
# the resulting op.
# NOTE(review): the schema argument names (input, input_scales, input_zp)
# differ from the Python parameter names (x, input_scale, input_zero_point) --
# confirm this is intentional.
name = "linear_dq8ca_q8csw"
lib.define(
f"""
{name}(
Tensor input,
Tensor input_scales,
Tensor input_zp,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias = None) -> Tensor
"""
)
# Bind the Python implementation under the "CompositeExplicitAutograd" key.
lib.impl(name, linear_dq8ca_q8csw, "CompositeExplicitAutograd")
# Look the op up as torch.ops.<namespace>.<name>.
linear_dq8ca_q8csw_op = getattr(getattr(torch.ops, namespace), name)

#################
## qaqw_linear ##
#################
Expand Down
17 changes: 17 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,23 @@ def register_linear_dq8ca_q4gsw():
)


@update_features(exir_ops.edge.et_vk.linear_dq8ca_q8csw.default)
def register_linear_dq8ca_q8csw():
    """Build the ``OpFeatures`` entry for ``et_vk.linear_dq8ca_q8csw``.

    The storage list has one entry per op argument, in argument order: the
    activation, its scale and zero point, followed by four prepacked
    arguments that take no storage.
    """
    # Arguments that are not prepacked.
    dynamic_inputs = [
        utils.CONTIGUOUS_ANY,  # input
        utils.WIDTH_PACKED_TEXTURE,  # input_scale
        utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
    ]
    # weight, weight_sums, weight_scales and bias are all prepacked.
    prepacked_inputs = [utils.NO_STORAGE] * 4

    return OpFeatures(
        inputs_storage=dynamic_inputs + prepacked_inputs,
        inputs_dtypes=utils.FP_T,
        supports_prepacking=True,
    )


# =============================================================================
# QuantizeDequantize.cpp
# =============================================================================
Expand Down
99 changes: 94 additions & 5 deletions backends/vulkan/patterns/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,14 @@ def is_weight_pergroup_quantized(self) -> bool:
def is_weight_perchannel_quantized(self) -> bool:
    """Return whether the weight scales describe per-output-channel quantization.

    Two scale layouts are accepted, where N is ``weight_shape[-2]``:

    * 1D ``[N]`` -- standard PT2E per-channel scales.
    * 2D ``[N, 1]`` -- produced by the torchao source-transform with
      ``PerAxis(0)`` (a single "group" covering the whole row).

    Any other scale shape is not per-channel.
    """
    weight_shape = self.weight_node.meta["val"].shape
    scales_shape = self.weight_scales_node.meta["val"].shape
    # Scales must have one entry per output channel of the weight.
    out_channels = weight_shape[-2]

    # Standard PT2E per-channel: scales is 1D [N].
    if len(scales_shape) == 1:
        return scales_shape[0] == out_channels
    # torchao source-transform with PerAxis(0) produces 2D [N, 1]; treat
    # that as per-channel as well.
    if len(scales_shape) == 2 and scales_shape[-1] == 1:
        return scales_shape[-2] == out_channels
    return False

def is_input_static_per_tensor_quantized(self) -> bool:
if self.dequantize_input_node is None:
Expand Down Expand Up @@ -489,6 +492,85 @@ def make_linear_dq8ca_q4gsw_op(
match.output_node.replace_all_uses_with(qlinear_node)


def make_linear_dq8ca_q8csw_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedLinearMatch,
    weight_tensor: torch.Tensor,
    weight_scales_tensor: torch.Tensor,
):
    """Replace a matched quantized linear with ``et_vk.linear_dq8ca_q8csw``.

    Steps, in order:

    1. Pass the weight, the (possibly squeezed) weight scales and, when
       present, the bias through ``utils.align_width_and_update_state_dict``.
    2. Compute per-output-channel weight sums, pad them to a multiple of 4
       and register them as a new constant parameter placeholder.
    3. Insert the custom op node before the matched output node and redirect
       all users of the matched output to it.

    Args:
        ep: Exported program that owns the parameters being updated.
        graph_module: Graph module whose graph is rewritten in place.
        match: Pattern match that supplies the nodes to wire into the new op.
        weight_tensor: Weight tensor for ``match.weight_node``.
        weight_scales_tensor: Scales tensor for ``match.weight_scales_node``.
    """
    graph = graph_module.graph

    # Per-channel symmetric INT8 weight: no group_size, no nibble packing.
    # NOTE(review): the previous comment spoke of aligning the width to 4,
    # but align_to=1 is what is passed here -- confirm which is intended.
    utils.align_width_and_update_state_dict(
        ep,
        match.weight_node,
        weight_tensor,
        align_to=1,
        force_update=True,
    )

    # torchao source-transform produces 2D [N, 1] scales; flatten those to
    # 1D [N] so they have the same shape as standard PT2E per-channel scales.
    scales_are_column = (
        weight_scales_tensor.dim() == 2 and weight_scales_tensor.shape[-1] == 1
    )
    if scales_are_column:
        weight_scales_tensor = weight_scales_tensor.squeeze(-1).contiguous()

    utils.align_width_and_update_state_dict(
        ep,
        match.weight_scales_node,
        weight_scales_tensor,
        align_to=1,
        force_update=True,
    )

    if match.bias_node is not None:
        bias_tensor = get_param_tensor(ep, match.bias_node)
        if bias_tensor is not None:
            utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

    # Pre-compute per-output-channel weight sums for input zero-point
    # correction during integer accumulation.
    channel_sums = weight_tensor.sum(dim=1).to(torch.int32).contiguous()
    # Pad OC up to the next multiple of 4 to keep shader loads in-bounds.
    pad_amount = (-channel_sums.shape[0]) % 4
    if pad_amount:
        channel_sums = F.pad(channel_sums, (0, pad_amount)).contiguous()

    # Name the new constant after the weight tensor, with dots replaced.
    base_name = utils.get_tensor_name(ep, match.weight_node)
    sums_name = (base_name + "_sums").replace(".", "_")

    # Create the placeholder before the first node of the graph.
    first_graph_node = list(graph.nodes)[0]
    with graph.inserting_before(first_graph_node):
        weight_sums_node = create_constant_placeholder(
            exp_program=ep,
            graph=graph,
            kind=InputKind.PARAMETER,
            name=sums_name,
            data=channel_sums,
        )

    # Argument order follows the op signature.
    op_args = (
        match.pattern_input_node,
        match.input_scales_node,
        match.input_zeros_node,
        match.weight_node,
        weight_sums_node,
        match.weight_scales_node,
        match.bias_node,
    )
    with graph.inserting_before(match.output_node):
        qlinear_node = graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.linear_dq8ca_q8csw.default,
            args=op_args,
        )

    # The replacement node inherits the matched output's "val" metadata.
    qlinear_node.meta["val"] = match.output_node.meta["val"]
    match.output_node.replace_all_uses_with(qlinear_node)


def make_linear_q8ta_q8csw_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
Expand Down Expand Up @@ -670,6 +752,13 @@ def replace_quantized_linear_patterns(
make_linear_dq8ca_q4gsw_op(
ep, graph_module, match, weight_tensor, weight_scales_tensor
)
elif (
match.is_input_dynamic_perchannel_quantized()
and match.is_weight_perchannel_quantized()
):
make_linear_dq8ca_q8csw_op(
ep, graph_module, match, weight_tensor, weight_scales_tensor
)
elif (
match.is_input_static_per_tensor_quantized()
and match.is_weight_perchannel_quantized()
Expand Down
Loading
Loading