Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ cc_library(
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:math",
"@highway//:profiler",
],
)
Expand Down
65 changes: 25 additions & 40 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "gemma/attention.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/contrib/math/fast_math-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
Expand Down Expand Up @@ -523,25 +524,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
hn::StoreU(new_max, df4, old_max);
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
const VF new_max_i = hn::Set(df, old_max[i]);
x_p0 = hn::FastExp(df, hn::Sub(x_p0, new_max_i));
x_p1 = hn::FastExp(df, hn::Sub(x_p1, new_max_i));
};
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
apply_exp(0, x_0_p0, x_0_p1);
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
apply_exp(1, x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
apply_exp(2, x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
apply_exp(3, x_3_p0, x_3_p1);
}
VF4 old_d_vf = hn::Set(df4, 0.0f);
old_d_vf = hn::LoadU(df4, old_d);
Expand Down Expand Up @@ -592,10 +590,6 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
}
}

template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE VF CallExp(DF df, VF x_p0) {
return hn::Exp(df, x_p0);
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
Expand Down Expand Up @@ -649,45 +643,36 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
hn::StoreU(new_max, df8, old_max);

auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
const VF new_max_i = hn::Set(df, old_max[i]);
x_p0 = hn::Exp(df, hn::Sub(x_p0, new_max_i));
x_p1 = hn::Exp(df, hn::Sub(x_p1, new_max_i));
};

if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
apply_exp(0, x_0_p0, x_0_p1);
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
apply_exp(1, x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
apply_exp(2, x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
apply_exp(3, x_3_p0, x_3_p1);
}
if constexpr (kNumQueries >= 5) {
const VF new_max_0 = hn::Set(df, old_max[4]);
x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0));
x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0));
apply_exp(4, x_4_p0, x_4_p1);
}
if constexpr (kNumQueries >= 6) {
const VF new_max_0 = hn::Set(df, old_max[5]);
x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0));
x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0));
apply_exp(5, x_5_p0, x_5_p1);
}
if constexpr (kNumQueries >= 7) {
const VF new_max_0 = hn::Set(df, old_max[6]);
x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0));
x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0));
apply_exp(6, x_6_p0, x_6_p1);
}
if constexpr (kNumQueries >= 8) {
const VF new_max_0 = hn::Set(df, old_max[7]);
x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0));
x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0));
apply_exp(7, x_7_p0, x_7_p1);
}
VF8 old_d_vf = hn::Set(df8, 0.0f);
old_d_vf = hn::LoadU(df8, old_d);
Expand Down
6 changes: 3 additions & 3 deletions gemma/flash_attention_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ void TestTiledFlashAttention() {
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f)
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f);
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-4f);
}
}
}
Expand Down Expand Up @@ -464,7 +464,7 @@ void TestTiledFlashAttentionBF16() {
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
Expand Down
Loading