diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 13fe9fa0d4..046792727d 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1101,6 +1101,8 @@ sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" use_splash_scheduler: false # to use tokamax splash attention scheduler. +sa_fuse_reciprocal: true # maps to fuse_reciprocal in the Tokamax SplashConfig; Tokamax itself defaults this to true +sa_use_base2_exp: true # maps to use_base2_exp in the Tokamax SplashConfig; Tokamax itself defaults this to true # local_sa_* variants apply to local (sliding window) attention layers; # if None, each inherits from the corresponding global sa_* flag. local_sa_block_q: None # inherits from sa_block_q if None @@ -1116,6 +1118,8 @@ local_sa_q_layout: None # inherits from sa_q_layout if None local_sa_k_layout: None # inherits from sa_k_layout if None local_sa_v_layout: None # inherits from sa_v_layout if None local_use_splash_scheduler: None # inherits from use_splash_scheduler if None +local_sa_fuse_reciprocal: None # inherits from sa_fuse_reciprocal if None +local_sa_use_base2_exp: None # inherits from sa_use_base2_exp if None use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as max logit estimate cost_estimate_flops_fwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (forward) cost_estimate_flops_bwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (backward) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 8ab8ee51ed..7b53616e5e 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -700,6 +700,8 @@ class SplashAttention(BaseModel): sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.") sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.") use_splash_scheduler: bool = 
Field(False, description="Use experimental splash attention scheduler.") + sa_fuse_reciprocal: bool = Field(True, description="Maps to fuse_reciprocal in SplashConfig.") + sa_use_base2_exp: bool = Field(True, description="Maps to use_base2_exp in SplashConfig.") # If None, each local_sa_* flag inherits from the corresponding sa_* flag. local_sa_block_q: int | None = Field(None, description="Block size for Q in local splash attention.") local_sa_block_kv: int | None = Field(None, description="Block size for KV in local splash attention.") @@ -718,6 +720,8 @@ class SplashAttention(BaseModel): local_sa_k_layout: str | None = Field(None, description="Layout for K in local splash attention.") local_sa_v_layout: str | None = Field(None, description="Layout for V in local splash attention.") local_use_splash_scheduler: bool | None = Field(None, description="Use experimental local splash attention scheduler.") + local_sa_fuse_reciprocal: bool | None = Field(None, description="Maps to local fuse_reciprocal in SplashConfig.") + local_sa_use_base2_exp: bool | None = Field(None, description="Maps to local use_base2_exp in SplashConfig.") use_max_logit_estimate: int = Field( -1, description="-1 means no estimate, any > 0 value will be used as max logit estimate", @@ -2984,6 +2988,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.local_sa_v_layout = self.sa_v_layout if self.local_use_splash_scheduler is None: self.local_use_splash_scheduler = self.use_splash_scheduler + if self.local_sa_fuse_reciprocal is None: + self.local_sa_fuse_reciprocal = self.sa_fuse_reciprocal + if self.local_sa_use_base2_exp is None: + self.local_sa_use_base2_exp = self.sa_use_base2_exp # I. 
RUN ALL CROSS-FIELD VALIDATIONS if self.load_parameters_path and self.load_full_state_path: diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index f3eb515547..62a99381e8 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -495,6 +495,8 @@ def __init__( self.k_layout = self.config.local_sa_k_layout self.v_layout = self.config.local_sa_v_layout self.use_splash_scheduler = self.config.local_use_splash_scheduler + self.fuse_reciprocal = self.config.local_sa_fuse_reciprocal + self.use_base2_exp = self.config.local_sa_use_base2_exp else: self.block_q = self.config.sa_block_q self.block_kv = self.config.sa_block_kv @@ -509,6 +511,8 @@ def __init__( self.k_layout = self.config.sa_k_layout self.v_layout = self.config.sa_v_layout self.use_splash_scheduler = self.config.use_splash_scheduler + self.fuse_reciprocal = self.config.sa_fuse_reciprocal + self.use_base2_exp = self.config.sa_use_base2_exp self.attn_logits_soft_cap = attn_logits_soft_cap self.sliding_window_size = sliding_window_size self.chunk_attn_window_size = chunk_attn_window_size @@ -1226,6 +1230,8 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): k_layout=tokamax_splash_kernel.QKVLayout[self.k_layout], v_layout=tokamax_splash_kernel.QKVLayout[self.v_layout], attn_logits_soft_cap=attn_logits_soft_cap, + fuse_reciprocal=self.fuse_reciprocal, + use_base2_exp=self.use_base2_exp, residual_checkpoint_name="context", fwd_cost_estimate=pl.CostEstimate( flops=config.cost_estimate_flops_fwd,