From f692d6812f37a1038a88a63073103cb93eb8eea6 Mon Sep 17 00:00:00 2001
From: qinyiqun
Date: Fri, 27 Feb 2026 11:00:35 +0800
Subject: [PATCH 1/2] Issue/243: Support w4a16 awq fp16 inference
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 csrc/layers/fused_linear.cpp | 77 +++++++++++++++++++++++++++++++++
 csrc/layers/fused_linear.hpp | 83 ++++++++++++++++++++++++------------
 2 files changed, 133 insertions(+), 27 deletions(-)

diff --git a/csrc/layers/fused_linear.cpp b/csrc/layers/fused_linear.cpp
index 6315ea2b..2ff5ffbb 100644
--- a/csrc/layers/fused_linear.cpp
+++ b/csrc/layers/fused_linear.cpp
@@ -170,6 +170,58 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
         0, tp_rank_, tp_size_);
 }
 
+infinicore::nn::Parameter QKVParallelLinear::get_q_weight_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_->narrow({{1, 0, q_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_k_weight_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_v_weight_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_scale_->narrow({{1, 0, q_out_size_ / scaling_factor}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_scale_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_scale_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_zeros_->narrow({{1, 0, q_out_size_ / scaling_factor}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_zeros_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros_awq(int scaling_factor) const {
+    return infinicore::nn::Parameter(
+        weight_zeros_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
+        1, tp_rank_, tp_size_);
+}
+
 infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
     return infinicore::nn::Parameter(
         weight_zeros_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
@@ -320,4 +372,29 @@ bool GateUpParallelLinear::has_gate_bias() const {
 bool GateUpParallelLinear::has_up_bias() const {
     return up_bias_;
 }
+
+infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_awq() const {
+    return infinicore::nn::Parameter(weight_->narrow({{1, 0, weight_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_awq() const {
+    return infinicore::nn::Parameter(weight_->narrow({{1, weight_->size(1) / 2, weight_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale_awq() const {
+    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, weight_scale_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale_awq() const {
+    return infinicore::nn::Parameter(weight_scale_->narrow({{1, weight_scale_->size(1) / 2, weight_scale_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros_awq() const {
+    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, weight_zeros_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
+infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros_awq() const {
+    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, weight_zeros_->size(1) / 2, weight_zeros_->size(1) / 2}}), 1, tp_rank_, tp_size_);
+}
+
 } // namespace infinilm::layers

diff --git a/csrc/layers/fused_linear.hpp b/csrc/layers/fused_linear.hpp
index 75748fc6..e70094e0 100644
--- a/csrc/layers/fused_linear.hpp
+++ b/csrc/layers/fused_linear.hpp
@@ -58,6 +58,21 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
     infinicore::nn::Parameter get_k_weight_zeros() const;
     infinicore::nn::Parameter get_v_weight_zeros() const;
 
+    // scaling_factor is the AWQ packing factor: the number of low-bit values
+    // packed into one higher-bit container element. For example, int4 packed
+    // into int32 gives a packing factor of 8 (32 bits / 4 bits = 8 int4 values per int32).
+    infinicore::nn::Parameter get_q_weight_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_k_weight_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_v_weight_awq(int scaling_factor) const;
+
+    infinicore::nn::Parameter get_q_weight_scale_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_k_weight_scale_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_v_weight_scale_awq(int scaling_factor) const;
+
+    infinicore::nn::Parameter get_q_weight_zeros_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_k_weight_zeros_awq(int scaling_factor) const;
+    infinicore::nn::Parameter get_v_weight_zeros_awq(int scaling_factor) const;
+
     infinicore::nn::Parameter get_q_bias() const;
     infinicore::nn::Parameter get_k_bias() const;
     infinicore::nn::Parameter get_v_bias() const;
@@ -132,6 +147,18 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
     infinicore::nn::Parameter get_up_bias() const;
 
+    infinicore::nn::Parameter get_gate_weight_awq() const;
+
+    infinicore::nn::Parameter get_up_weight_awq() const;
+
+    infinicore::nn::Parameter get_gate_weight_scale_awq() const;
+
+    infinicore::nn::Parameter get_up_weight_scale_awq() const;
+
+    infinicore::nn::Parameter get_gate_weight_zeros_awq() const;
+
+    infinicore::nn::Parameter get_up_weight_zeros_awq() const;
+
     bool has_gate_bias() const;
     bool has_up_bias() const;
 
@@ -178,22 +205,24 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
     if (name##_->has_v_bias()) \
         this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
 
-#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
-    name##_ = std::make_shared<QKVParallelLinear>(__VA_ARGS__); \
-    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
-    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
-    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
-    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
-    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
-    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
-    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
-    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
-    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
-    if (name##_->has_q_bias()) \
-        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
-    if (name##_->has_k_bias()) \
-        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
-    if (name##_->has_v_bias()) \
+#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
+    name##_ = std::make_shared<QKVParallelLinear>(__VA_ARGS__); \
+    auto awq_ptr = std::static_pointer_cast(this->quantization_); \
+    int packing_num = awq_ptr->get_packing_num(); \
+    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
+    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale_awq(1)); \
+    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight_awq(packing_num)); \
+    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale_awq(1)); \
+    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight_awq(packing_num)); \
+    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale_awq(1)); \
+    if (name##_->has_q_bias()) \
+        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
+    if (name##_->has_k_bias()) \
+        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
+    if (name##_->has_v_bias()) \
         this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
 
 // ========================= Gate-Up Quantization ==============================
@@ -208,16 +237,16 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
     if (name##_->has_up_bias()) \
         this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
 
-#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
-    name##_ = std::make_shared<GateUpParallelLinear>(__VA_ARGS__); \
-    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
-    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
-    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
-    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
-    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
-    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
-    if (name##_->has_gate_bias()) \
-        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
-    if (name##_->has_up_bias()) \
+#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
+    name##_ = std::make_shared<GateUpParallelLinear>(__VA_ARGS__); \
+    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq()); \
+    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \
+    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \
+    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight_awq()); \
+    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros_awq()); \
+    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale_awq()); \
+    if (name##_->has_gate_bias()) \
+        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
+    if (name##_->has_up_bias()) \
         this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
 
 } // namespace infinilm::layers

From fc97bbd8ef5fd9ecc4de04186939c89c2a957efa Mon Sep 17 00:00:00 2001
From: qinyiqun
Date: Wed, 11 Mar 2026 16:32:34 +0800
Subject: [PATCH 2/2] Fix: Add lifecycle management to AWQ linear function

---
 csrc/layers/fused_linear.hpp          | 2 +-
 csrc/models/llama/llama_attention.cpp | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/csrc/layers/fused_linear.hpp b/csrc/layers/fused_linear.hpp
index e70094e0..2e1217b6 100644
--- a/csrc/layers/fused_linear.hpp
+++ b/csrc/layers/fused_linear.hpp
@@ -207,7 +207,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
 
 #define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
     name##_ = std::make_shared<QKVParallelLinear>(__VA_ARGS__); \
-    auto awq_ptr = std::static_pointer_cast(this->quantization_); \
+    auto awq_ptr = std::static_pointer_cast(name##_->get_quantization()); \
     int packing_num = awq_ptr->get_packing_num(); \
     this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
     this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \

diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp
index a6b5ab78..6fee7ce5 100644
--- a/csrc/models/llama/llama_attention.cpp
+++ b/csrc/models/llama/llama_attention.cpp
@@ -112,12 +112,13 @@ LlamaAttention::LlamaAttention(std::shared_ptr mo
                                   dtype, device, tp_rank, tp_size, rank_info.comm);
         break;
-    case infinicore::quantization::QuantScheme::AWQ_W4A16:
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
         INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_,
                                           model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"),
                                           this->model_config_->get_quantization_method(), use_bias_, dtype, device, rank_info);
         INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get("num_attention_heads") * head_dim_, hidden_size_,
                                   this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, rank_info.comm);
         break;
+    }
     default:
         INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_,
                                  model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"),
                                  this->model_config_->get_quantization_method(), use_bias_, dtype, device, rank_info);
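
Reviewer note on the packing arithmetic used in these patches. The sketch below is illustrative only and is not code from this repository: the ColSlice struct and the example sizes are hypothetical stand-ins for the column ranges produced by weight_->narrow({{1, start, len}}), and it assumes the usual AutoAWQ convention that qweight and qzeros pack eight int4 values per int32 along the output dimension while scales stay unpacked, which is why the macro passes packing_num to the qweight/qzeros getters but 1 to the scale getters. It also assumes the template argument of std::static_pointer_cast (lost in this copy of the patch) names the project's AWQ quantization class, whose get_packing_num() returns this factor.

#include <cassert>
#include <cstdio>

// Hypothetical stand-in for a column range along dim 1 of a packed tensor.
struct ColSlice {
    int start; // first column along dim 1
    int len;   // number of columns
};

int main() {
    const int bits = 4;
    const int container_bits = 32;
    const int packing = container_bits / bits; // 8 int4 values per int32
    assert(packing == 8);

    // Illustrative fused-QKV output widths (not taken from any real model).
    const int q_out = 4096, k_out = 1024, v_out = 1024;

    // qweight/qzeros are packed along the output dimension, so both the
    // offset and the length shrink by the packing factor -- the same
    // division the get_{q,k,v}_weight_awq(packing_num) getters perform.
    ColSlice q{0, q_out / packing};
    ColSlice k{q_out / packing, k_out / packing};
    ColSlice v{(q_out + k_out) / packing, v_out / packing};

    // scales are stored unpacked, hence get_*_weight_scale_awq(1):
    // dividing by 1 leaves the raw column offsets untouched.
    ColSlice q_scale{0, q_out};

    std::printf("qweight  cols [%d, %d)\n", q.start, q.start + q.len);
    std::printf("kweight  cols [%d, %d)\n", k.start, k.start + k.len);
    std::printf("vweight  cols [%d, %d)\n", v.start, v.start + v.len);
    std::printf("q scales cols [%d, %d)\n", q_scale.start, q_scale.start + q_scale.len);
    return 0;
}

On the PATCH 2/2 change itself: reading the quantization object from name##_->get_quantization() instead of this->quantization_ ties packing_num to the layer that actually owns the quantized weights, which appears to be the lifecycle issue the commit title refers to.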