diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index 6dc96f0517..8d19cee75c 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -42,7 +42,6 @@
     "intermediate_source/torchrec_intro_tutorial.py", #failing with 2.8 reenable after 3498
     "intermediate_source/torch_export_tutorial.py", # failing with 2.11 issue #3773
     "beginner_source/mosaic_memory_profiling_tutorial.py", # failing with 2.11 issue #3774
-    "intermediate_source/variable_length_attention_tutorial.py", # failing with 2.11 issue #3775
 ]
 
 def tutorial_source_dirs() -> List[Path]:
diff --git a/intermediate_source/variable_length_attention_tutorial.py b/intermediate_source/variable_length_attention_tutorial.py
index 261695723e..7c452b9546 100644
--- a/intermediate_source/variable_length_attention_tutorial.py
+++ b/intermediate_source/variable_length_attention_tutorial.py
@@ -99,8 +99,10 @@
 # cu_seq_k: torch.Tensor,
 # max_q: int,
 # max_k: int,
-# is_causal: bool = False,
+# *,
 # return_aux: AuxRequest | None = None,
+# scale: float | None = None,
+# window_size: tuple[int, int] = (-1, -1),
 # ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
 #
 # ``query``, ``key``, and ``value`` correspond to the ``q``, ``k``, and
@@ -108,18 +110,22 @@
 # cumulative indices for query and key/value, respectively. These mark the
 # logical boundaries that separate the documents in our input. ``max_q``
 # and ``max_k`` are the maximum sequence lengths of query and key,
-# respectively. ``is_causal`` applies causal masking if set to True and
-# ``return_aux`` specifies which auxiliary outputs to return (ie ``lse``).
+# respectively. ``return_aux`` specifies which auxiliary outputs to return
+# (ie ``lse``). ``scale`` is an optional scaling factor applied to the
+# attention scores before softmax. ``window_size`` is a ``(left, right)``
+# tuple that controls sliding window attention: use ``(-1, -1)`` for full
+# attention (default), ``(-1, 0)`` for causal attention, or ``(W, 0)``
+# for causal attention with a sliding window of size ``W``.
 
 ######################################################################
 # **Note on causal masking**
-# When ``is_causal`` is set to True, causal masking is applied which means
-# that tokens can only attend to previous tokens. For bidirectional
-# attention, set this flag to False.
+# When ``window_size`` is set to ``(-1, 0)``, causal masking is applied
+# which means that tokens can only attend to previous tokens. For
+# bidirectional (full) attention, use the default ``(-1, -1)``.
 #
 # In torchtitan (PyTorch's pretraining framework), we set
-# ``is_causal = True`` uniformly to prevent the model from cheating and
-# artificially driving the loss down too quickly.
+# ``window_size = (-1, 0)`` uniformly to prevent the model from cheating
+# and artificially driving the loss down too quickly.
 
 
 ######################################################################
@@ -241,7 +247,7 @@ def forward(
             cu_seq_k=cu_seq,
             max_q=max_len,
             max_k=max_len,
-            is_causal=True,
+            window_size=(-1, 0),
         )
         attn_out = attn_out.view(-1, self.embed_dim)
         attn_out = self.out_proj(attn_out)
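
For reference, a minimal sketch of calling the updated API with the new ``window_size`` argument. This assumes the ``torch.nn.attention.varlen.varlen_attn`` import path and the packed ``(total_tokens, num_heads, head_dim)`` layout used in the tutorial; the CUDA device, bfloat16 dtype, and example shapes are illustrative and may need adjusting for a given PyTorch build.

    import torch
    from torch.nn.attention.varlen import varlen_attn  # import path assumed from the tutorial

    # Two packed "documents" of lengths 3 and 5 in one ragged batch.
    seq_lens = [3, 5]
    total_tokens = sum(seq_lens)
    num_heads, head_dim = 4, 64

    # Cumulative sequence boundaries [0, 3, 8] mark where each document starts/ends.
    cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    max_len = max(seq_lens)

    q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

    # window_size=(-1, 0) applies causal masking within each document,
    # replacing the removed is_causal=True flag; the default (-1, -1)
    # would give full bidirectional attention instead.
    out = varlen_attn(
        q, k, v,
        cu_seq_q=cu_seq,
        cu_seq_k=cu_seq,
        max_q=max_len,
        max_k=max_len,
        window_size=(-1, 0),
    )
    print(out.shape)  # (total_tokens, num_heads, head_dim)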