Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,8 @@ sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
use_splash_scheduler: false # to use tokamax splash attention scheduler.
sa_fuse_reciprocal: true # defaults to true in Tokamax
sa_use_base2_exp: true # defaults to true in Tokamax
# local_sa_* variants apply to local (sliding window) attention layers;
# if None, each inherits from the corresponding global sa_* flag.
local_sa_block_q: None # inherits from sa_block_q if None
Expand All @@ -1116,6 +1118,8 @@ local_sa_q_layout: None # inherits from sa_q_layout if None
local_sa_k_layout: None # inherits from sa_k_layout if None
local_sa_v_layout: None # inherits from sa_v_layout if None
local_use_splash_scheduler: None # inherits from use_splash_scheduler if None
local_sa_fuse_reciprocal: None # inherits from sa_fuse_reciprocal if None
local_sa_use_base2_exp: None # inherits from sa_use_base2_exp if None
use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as max logit estimate
cost_estimate_flops_fwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (forward)
cost_estimate_flops_bwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (backward)
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,8 @@ class SplashAttention(BaseModel):
sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.")
sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.")
use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")
sa_fuse_reciprocal: bool = Field(True, description="Maps to fuse_reciprocal in SplashConfig.")
sa_use_base2_exp: bool = Field(True, description="Maps to use_base2_exp in SplashConfig.")
# If None, each local_sa_* flag inherits from the corresponding sa_* flag.
local_sa_block_q: int | None = Field(None, description="Block size for Q in local splash attention.")
local_sa_block_kv: int | None = Field(None, description="Block size for KV in local splash attention.")
Expand All @@ -718,6 +720,8 @@ class SplashAttention(BaseModel):
local_sa_k_layout: str | None = Field(None, description="Layout for K in local splash attention.")
local_sa_v_layout: str | None = Field(None, description="Layout for V in local splash attention.")
local_use_splash_scheduler: bool | None = Field(None, description="Use experimental local splash attention scheduler.")
local_sa_fuse_reciprocal: bool | None = Field(None, description="Maps to local fuse_reciprocal in SplashConfig.")
local_sa_use_base2_exp: bool | None = Field(None, description="Maps to local use_base2_exp in SplashConfig.")
use_max_logit_estimate: int = Field(
-1,
description="-1 means no estimate, any > 0 value will be used as max logit estimate",
Expand Down Expand Up @@ -2984,6 +2988,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.local_sa_v_layout = self.sa_v_layout
if self.local_use_splash_scheduler is None:
self.local_use_splash_scheduler = self.use_splash_scheduler
if self.local_sa_fuse_reciprocal is None:
self.local_sa_fuse_reciprocal = self.sa_fuse_reciprocal
if self.local_sa_use_base2_exp is None:
self.local_sa_use_base2_exp = self.sa_use_base2_exp

# I. RUN ALL CROSS-FIELD VALIDATIONS
if self.load_parameters_path and self.load_full_state_path:
Expand Down
6 changes: 6 additions & 0 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ def __init__(
self.k_layout = self.config.local_sa_k_layout
self.v_layout = self.config.local_sa_v_layout
self.use_splash_scheduler = self.config.local_use_splash_scheduler
self.fuse_reciprocal = self.config.local_sa_fuse_reciprocal
self.use_base2_exp = self.config.local_sa_use_base2_exp
else:
self.block_q = self.config.sa_block_q
self.block_kv = self.config.sa_block_kv
Expand All @@ -509,6 +511,8 @@ def __init__(
self.k_layout = self.config.sa_k_layout
self.v_layout = self.config.sa_v_layout
self.use_splash_scheduler = self.config.use_splash_scheduler
self.fuse_reciprocal = self.config.sa_fuse_reciprocal
self.use_base2_exp = self.config.sa_use_base2_exp
self.attn_logits_soft_cap = attn_logits_soft_cap
self.sliding_window_size = sliding_window_size
self.chunk_attn_window_size = chunk_attn_window_size
Expand Down Expand Up @@ -1226,6 +1230,8 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
k_layout=tokamax_splash_kernel.QKVLayout[self.k_layout],
v_layout=tokamax_splash_kernel.QKVLayout[self.v_layout],
attn_logits_soft_cap=attn_logits_soft_cap,
fuse_reciprocal=self.fuse_reciprocal,
use_base2_exp=self.use_base2_exp,
residual_checkpoint_name="context",
fwd_cost_estimate=pl.CostEstimate(
flops=config.cost_estimate_flops_fwd,
Expand Down
Loading