Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions csrc/layers/fused_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,58 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
0, tp_rank_, tp_size_);
}

// AWQ-packed Q projection weight: the packed dimension (dim 1) shrinks by the
// packing factor, so the Q slice is q_out_size_ / scaling_factor wide.
// Assumes scaling_factor divides q_out_size_ evenly — TODO confirm at call sites.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_awq(int scaling_factor) const {
    const auto packed_cols = q_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_->narrow({{1, 0, packed_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ-packed K projection weight: begins immediately after the packed Q slice
// along dim 1, with extent k_out_size_ / scaling_factor.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_awq(int scaling_factor) const {
    const auto packed_start = q_out_size_ / scaling_factor;
    const auto packed_cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_->narrow({{1, packed_start, packed_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ-packed V projection weight: the tail slice after Q and K along dim 1.
// The offset is computed as (q + k) / factor — not q/factor + k/factor, which
// can differ under truncating division — mirroring the other *_awq getters.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_awq(int scaling_factor) const {
    const auto packed_start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto packed_cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_->narrow({{1, packed_start, packed_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ Q scales: same column partitioning as the packed weight, divided by the
// caller-supplied factor (the W4A16 init macro passes 1 for scales).
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale_awq(int scaling_factor) const {
    const auto cols = q_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ K scales: slice of dim 1 starting after the Q scales.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale_awq(int scaling_factor) const {
    const auto start = q_out_size_ / scaling_factor;
    const auto cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, start, cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ V scales: tail slice of dim 1 after the Q and K scales; the offset is
// (q + k) / factor to stay consistent with the packed-weight getters.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale_awq(int scaling_factor) const {
    const auto start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, start, cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ Q zero-points: packed along dim 1 like the weights, so the Q slice is
// q_out_size_ / scaling_factor wide.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros_awq(int scaling_factor) const {
    const auto cols = q_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ K zero-points: slice of dim 1 starting after the packed Q zeros.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros_awq(int scaling_factor) const {
    const auto start = q_out_size_ / scaling_factor;
    const auto cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, start, cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ V zero-points: tail slice of dim 1 after Q and K; offset computed as
// (q + k) / factor, matching the packed-weight getters.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros_awq(int scaling_factor) const {
    const auto start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, start, cols}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
return infinicore::nn::Parameter(
weight_zeros_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
Expand Down Expand Up @@ -320,4 +372,29 @@ bool GateUpParallelLinear::has_gate_bias() const {
// True when an up-projection bias is configured on this layer.
bool GateUpParallelLinear::has_up_bias() const {
    return up_bias_ ? true : false;
}

// Fused gate|up AWQ weight: the gate half occupies the first half of dim 1.
// Assumes the fused extent is even — TODO confirm against the fused layout.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_awq() const {
    const auto half_cols = weight_->size(1) / 2;
    return infinicore::nn::Parameter(weight_->narrow({{1, 0, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

// Fused gate|up AWQ weight: the up half starts at the midpoint of dim 1.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_awq() const {
    const auto half_cols = weight_->size(1) / 2;
    return infinicore::nn::Parameter(weight_->narrow({{1, half_cols, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ gate scales: first half of the fused scale tensor along dim 1.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale_awq() const {
    const auto half_cols = weight_scale_->size(1) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ up scales: second half of the fused scale tensor along dim 1.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale_awq() const {
    const auto half_cols = weight_scale_->size(1) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, half_cols, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ gate zero-points: first half of the fused zeros tensor along dim 1.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros_awq() const {
    const auto half_cols = weight_zeros_->size(1) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

// AWQ up zero-points: second half of the fused zeros tensor along dim 1.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros_awq() const {
    const auto half_cols = weight_zeros_->size(1) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, half_cols, half_cols}}),
                                     1, tp_rank_, tp_size_);
}

} // namespace infinilm::layers
83 changes: 56 additions & 27 deletions csrc/layers/fused_linear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
infinicore::nn::Parameter get_k_weight_zeros() const;
infinicore::nn::Parameter get_v_weight_zeros() const;

// AWQ accessors. `scaling_factor` is the packing factor: the number of
// low-bit values stored in one container element (e.g. int4 packed into
// int32 gives 8 = 32 bits / 4 bits). The narrowed extents along dim 1 are
// divided by this factor to match the packed layout.
infinicore::nn::Parameter get_q_weight_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_awq(int scaling_factor) const;

infinicore::nn::Parameter get_q_weight_scale_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_scale_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_scale_awq(int scaling_factor) const;

infinicore::nn::Parameter get_q_weight_zeros_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_zeros_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_zeros_awq(int scaling_factor) const;

infinicore::nn::Parameter get_q_bias() const;
infinicore::nn::Parameter get_k_bias() const;
infinicore::nn::Parameter get_v_bias() const;
Expand Down Expand Up @@ -132,6 +147,18 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {

infinicore::nn::Parameter get_up_bias() const;

infinicore::nn::Parameter get_gate_weight_awq() const;

infinicore::nn::Parameter get_up_weight_awq() const;

infinicore::nn::Parameter get_up_weight_scale_awq() const;

infinicore::nn::Parameter get_up_weight_zeros_awq() const;

infinicore::nn::Parameter get_gate_weight_scale_awq() const;

infinicore::nn::Parameter get_gate_weight_zeros_awq() const;

bool has_gate_bias() const;

bool has_up_bias() const;
Expand Down Expand Up @@ -178,22 +205,24 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
if (name##_->has_v_bias()) \
this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());

#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
if (name##_->has_v_bias()) \
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(name##_->get_quantization()); \
int packing_num = awq_ptr->get_packing_num(); \
this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale_awq(1)); \
this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight_awq(packing_num)); \
this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale_awq(1)); \
this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight_awq(packing_num)); \
this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale_awq(1)); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
if (name##_->has_v_bias()) \
this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());

// ========================= Gate-Up Quantization ==============================
Expand All @@ -208,16 +237,16 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());

#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq()); \
this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \
this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \
this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight_awq()); \
this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros_awq()); \
this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale_awq()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
} // namespace infinilm::layers
3 changes: 2 additions & 1 deletion csrc/models/llama/llama_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,13 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
dtype, device, tp_rank, tp_size, rank_info.comm);
break;

case infinicore::quantization::QuantScheme::AWQ_W4A16:
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info);
INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
}
default:
INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info);
Expand Down