-
Notifications
You must be signed in to change notification settings - Fork 932
Add LFM2.5-VL export with CUDA/AOTI backend #18823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,112 +1,102 @@ | ||
| from typing import Optional | ||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| from executorch.examples.models.llama.attention import ForwardOptions | ||
| from executorch.examples.models.llama.feed_forward import FeedForward | ||
|
|
||
| from executorch.examples.models.llama.norm import RMSNorm | ||
| from torch import nn | ||
|
|
||
|
|
||
| class ShortConv(nn.Module): | ||
| def __init__( | ||
| self, | ||
| dim: int, | ||
| L_cache: int = 3, | ||
| bias: bool = False, | ||
| device: Optional[torch.device] = None, | ||
| dtype: Optional[torch.dtype] = None, | ||
| ): | ||
| """Depthwise short convolution with dual state management. | ||
|
|
||
| Supports two modes: | ||
| 1. State-as-IO: caller passes conv_state in and receives new state back. | ||
| Required for AOTI which cannot re-trace mutable buffer mutations. | ||
| 2. Internal buffer: uses register_buffer + copy_() for XNNPack/portable | ||
| backends where mutable buffers are handled natively. | ||
| """ | ||
|
|
||
| def __init__(self, dim: int, L_cache: int = 3, *, bias: bool = False) -> None: | ||
| super().__init__() | ||
| assert L_cache == 3, f"Manual depthwise conv only supports L_cache=3, got {L_cache}" | ||
| self.dim = dim | ||
| self.L_cache = L_cache | ||
| self.device = device | ||
| self.dtype = dtype | ||
| self.bias = bias | ||
|
|
||
| self.conv = nn.Conv1d( | ||
| dim, | ||
| dim, | ||
| kernel_size=L_cache, | ||
| padding=0, ## we don't need padding since we handle it manually | ||
| groups=dim, | ||
| bias=bias, | ||
| ) | ||
|
|
||
| conv_state = torch.zeros( | ||
| 1, ## batch size is assumed to be 1 for now | ||
| dim, | ||
| L_cache - 1, | ||
| device="cpu", | ||
| ) | ||
| self.register_buffer("conv_state", conv_state) | ||
|
|
||
| ## better performance in Executorch with separate projections | ||
| self.conv = nn.Conv1d(dim, dim, kernel_size=L_cache, padding=0, groups=dim, bias=bias) | ||
| self.B_proj = nn.Linear(dim, dim, bias=bias) | ||
| self.C_proj = nn.Linear(dim, dim, bias=bias) | ||
| self.x_proj = nn.Linear(dim, dim, bias=bias) | ||
|
|
||
| self.out_proj = nn.Linear(dim, dim, bias=bias) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| batch_size, seqlen, dim = x.size() | ||
| assert batch_size == 1, "batch_size must be 1" | ||
|
|
||
| B = self.B_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) | ||
| C = self.C_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) | ||
| x = self.x_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) | ||
|
|
||
| Bx = B * x # (batch_size, dim, seq_len) | ||
| self.register_buffer( | ||
| "conv_state", | ||
| torch.zeros(1, dim, L_cache - 1), | ||
| ) | ||
|
|
||
| ## This is where we handle padding | ||
| ## By default, the conv_state is initialized to 0. | ||
| # So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similarly to | ||
| ## using nn.Conv1d(padding=L_cache-1) (for prefill) without manual padding. | ||
| ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0. | ||
| Bx = torch.cat( | ||
| [self.conv_state, Bx], dim=-1 | ||
| ) # (batch_size, dim, seq_len + L_cache - 1) | ||
| def forward( | ||
| self, x: torch.Tensor, conv_state: torch.Tensor | None = None | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if conv_state is None: | ||
| conv_state = self.conv_state | ||
|
|
||
| ## Update the conv_state | ||
| new_conv_state = Bx[ | ||
| ..., -(self.L_cache - 1) : | ||
| ] # (batch_size, dim, L_cache - 1) | ||
| with torch.no_grad(): | ||
| self.conv_state.copy_(new_conv_state) | ||
| B = self.B_proj(x).transpose(-1, -2) | ||
| C = self.C_proj(x).transpose(-1, -2) | ||
| x = self.x_proj(x).transpose(-1, -2) | ||
|
|
||
| conv_out = self.conv(Bx)[..., : x.size(-1)] # (batch_size, dim, seq_len) | ||
| y = C * conv_out # (batch_size, dim, seq_len) | ||
| Bx = torch.cat([conv_state, B * x], dim=-1) | ||
| new_conv_state = Bx[..., -(self.L_cache - 1) :] | ||
|
|
||
| y = y.transpose(-1, -2) # (batch_size, seq_len, dim) | ||
| y = y.contiguous() # (batch_size, seq_len, dim) | ||
| y = self.out_proj(y) # (batch_size, seq_len, dim) | ||
| return y | ||
| # Manual depthwise conv — Triton has no template for nn.Conv1d | ||
| # with groups=dim and dynamic sequence length. | ||
| w = self.conv.weight[:, 0, :] | ||
| conv_out = Bx[..., :-2] * w[:, 0:1] + Bx[..., 1:-1] * w[:, 1:2] + Bx[..., 2:] * w[:, 2:3] | ||
|
|
||
| def reset_cache(self): | ||
| self.conv_state.zero_() | ||
| y = self.out_proj((C * conv_out).transpose(-1, -2).contiguous()) | ||
| return y, new_conv_state | ||
|
|
||
|
|
||
| class ShortConvBlock(nn.Module): | ||
| def __init__(self, dim: int, hidden_dim: int, norm_eps: float): | ||
| def __init__(self, dim: int, hidden_dim: int, norm_eps: float, layer_idx: int = -1) -> None: | ||
| super().__init__() | ||
| self.L_cache = 3 # hardcode 3 for now | ||
| self.conv = ShortConv(dim, self.L_cache, bias=False) | ||
| self.layer_idx = layer_idx | ||
| self.conv = ShortConv(dim, L_cache=3, bias=False) | ||
| self.feed_forward = FeedForward(dim, hidden_dim) | ||
| self.ffn_norm = RMSNorm(dim, norm_eps) | ||
| # use attention_norm instead of operator_norm to unify with TransformerBlock | ||
| self.attention_norm = RMSNorm(dim, norm_eps) | ||
|
|
||
| def forward( | ||
| self, | ||
| x, | ||
| freqs_cos=None, | ||
| freqs_sin=None, | ||
| _unused_attn_options: Optional[ForwardOptions] = None, | ||
| ): # x: 1xN | ||
| h = self.conv.forward(self.attention_norm(x)) | ||
| x: torch.Tensor, | ||
| freqs_cos: torch.Tensor | None = None, | ||
| freqs_sin: torch.Tensor | None = None, | ||
| attn_options: ForwardOptions | None = None, | ||
| ) -> tuple[torch.Tensor, dict]: | ||
| # State-as-IO: read from attn_options if provided (CUDA/AOTI path) | ||
| conv_state = None | ||
| if attn_options is not None: | ||
| conv_states = attn_options.get("conv_states") | ||
| if conv_states is not None: | ||
| conv_state = conv_states.get(self.layer_idx) | ||
|
|
||
| h, new_conv_state = self.conv(self.attention_norm(x), conv_state) | ||
| h = x + h | ||
| out = h + self.feed_forward(self.ffn_norm(h)) | ||
| return out, None | ||
|
|
||
| def reset_cache(self): | ||
| self.conv.reset_cache() | ||
| # Write back state | ||
| update: dict = {} | ||
| if attn_options is not None and "conv_states" in attn_options: | ||
| if conv_state is not None: | ||
| conv_state.copy_(new_conv_state) | ||
| states = dict(attn_options["conv_states"]) | ||
| states[self.layer_idx] = new_conv_state | ||
| update["conv_states"] = states | ||
|
Comment on lines
+88
to
+93
|
||
| else: | ||
| # XNNPack/portable path: persist via internal buffer | ||
| with torch.no_grad(): | ||
| self.conv.conv_state.copy_(new_conv_state) | ||
|
|
||
| return out, update | ||
|
|
||
| def reset_cache(self) -> None: | ||
| self.conv.conv_state.zero_() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from executorch.examples.models.lfm2_5_vl.convert_weights import convert_weights | ||
| from executorch.examples.models.lfm2_5_vl.model import Lfm2p5VlModel | ||
|
|
||
| __all__ = [ | ||
| "convert_weights", | ||
| "Lfm2p5VlModel", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| { | ||
| "dim": 2048, | ||
| "ffn_dim_multiplier": 1, | ||
| "hidden_dim": 8192, | ||
| "n_heads": 32, | ||
| "n_kv_heads": 8, | ||
| "n_layers": 16, | ||
| "norm_eps": 1e-5, | ||
| "rope_theta": 1000000.0, | ||
| "use_scaled_rope": false, | ||
| "vocab_size": 65536, | ||
| "use_hf_rope": true, | ||
| "use_qk_norm": true, | ||
| "qk_norm_before_rope": true, | ||
| "layer_types": [ | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv" | ||
| ] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| { | ||
| "dim": 1024, | ||
| "ffn_dim_multiplier": 1, | ||
| "hidden_dim": 4608, | ||
| "n_heads": 16, | ||
| "n_kv_heads": 8, | ||
| "n_layers": 16, | ||
| "norm_eps": 1e-5, | ||
| "rope_theta": 1000000.0, | ||
| "use_scaled_rope": false, | ||
| "vocab_size": 65536, | ||
| "use_hf_rope": true, | ||
| "use_qk_norm": true, | ||
| "qk_norm_before_rope": true, | ||
| "layer_types": [ | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv", | ||
| "full_attention", | ||
| "conv" | ||
| ] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """Convert LFM2.5-VL text decoder weights from HuggingFace to ET format.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
| from executorch.examples.models.checkpoint import get_mapped_key | ||
| from safetensors.torch import load_file | ||
|
|
||
| _LFM2_5_VL_TO_META: dict[str, str] = { | ||
| "model.language_model.embed_tokens.weight": "tok_embeddings.weight", | ||
| "model.language_model.embedding_norm.weight": "norm.weight", | ||
| "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", | ||
| "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", | ||
| "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", | ||
| "model.language_model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight", | ||
| "model.language_model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight", | ||
| "model.language_model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight", | ||
| "model.language_model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight", | ||
| "model.language_model.layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", | ||
| "model.language_model.layers.{}.feed_forward.w1.weight": "layers.{}.feed_forward.w1.weight", | ||
| "model.language_model.layers.{}.feed_forward.w2.weight": "layers.{}.feed_forward.w2.weight", | ||
| "model.language_model.layers.{}.feed_forward.w3.weight": "layers.{}.feed_forward.w3.weight", | ||
| "model.language_model.layers.{}.conv.conv.weight": "layers.{}.conv.conv.weight", | ||
| "model.language_model.layers.{}.conv.out_proj.weight": "layers.{}.conv.out_proj.weight", | ||
| "model.language_model.lm_head.weight": "output.weight", | ||
| } | ||
|
|
||
| _IN_PROJ_SPLITS = ("B_proj", "C_proj", "x_proj") | ||
|
|
||
|
|
||
| def lfm2_5_vl_to_meta(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: | ||
| """Extract and remap language model weights from a full VL state dict.""" | ||
| converted: dict[str, torch.Tensor] = {} | ||
|
|
||
| for key, value in state_dict.items(): | ||
| if not key.startswith("model.language_model."): | ||
| continue | ||
|
|
||
| try: | ||
| new_key = get_mapped_key(key, _LFM2_5_VL_TO_META) | ||
| except Exception: | ||
| new_key = key.removeprefix("model.language_model.") | ||
|
|
||
| if new_key.endswith(".conv.in_proj.weight"): | ||
| for name, chunk in zip(_IN_PROJ_SPLITS, torch.chunk(value, 3, dim=0)): | ||
| converted[new_key.replace("in_proj", name)] = chunk | ||
| else: | ||
| converted[new_key] = value | ||
|
|
||
| if "output.weight" not in converted: | ||
| converted["output.weight"] = converted["tok_embeddings.weight"] | ||
|
|
||
| return converted | ||
|
|
||
|
|
||
| def convert_weights(input_dir: str, output_file: str) -> None: | ||
| sd = load_file(str(Path(input_dir) / "model.safetensors")) | ||
| sd = lfm2_5_vl_to_meta(sd) | ||
| torch.save(sd, output_file) | ||
| print(f"Saved {len(sd)} tensors to {output_file}") | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser(description="Convert LFM2.5-VL weights to ET format.") | ||
| parser.add_argument("input_dir", help="Directory containing model.safetensors.") | ||
| parser.add_argument("output", help="Output .pt checkpoint path.") | ||
| args = parser.parse_args() | ||
| convert_weights(args.input_dir, args.output) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ShortConv.forward implements the convolution manually using self.conv.weight, but it ignores self.conv.bias when bias=True. This makes the bias argument silently incorrect. Either add the bias term to conv_out or enforce bias=False (e.g., via an assertion and/or by removing the parameter) to avoid surprising behavior.