From baf48bb705cf6f368fc0da91ccba558093e496c7 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Fri, 10 Apr 2026 21:11:06 +0000 Subject: [PATCH] Add LFM2.5-VL export with CUDA/AOTI backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Export LFM2.5-VL (450M and 1.6B) as a multi-method PTE with three methods: vision_encoder, token_embedding, and text_decoder, all delegated to the CUDA/AOTI backend. Key changes: - examples/models/lfm2_5_vl/: new model, weight converter, and export script for LFM2.5-VL on CUDA - examples/models/lfm2/short_conv.py: dual state management — state-as-IO for CUDA/AOTI (via attn_options["conv_states"]) with register_buffer fallback for XNNPack/portable backends - examples/models/llama/llama_transformer.py: pass layer_idx to ShortConvBlock for per-layer conv state keying - exir/emit/_emitter.py: copy CUDA tensor storage to CPU before ctypes pointer read to prevent segfault during serialization Tested on NVIDIA B300: 333-400 decode tok/s, 435-454 prefill tok/s, correct, coherent generation on text-only and vision-language prompts. Also compatible with the llama_main C++ runner. 
--- examples/models/lfm2/short_conv.py | 142 +++++----- examples/models/lfm2_5_vl/__init__.py | 13 + .../config/lfm2_5_vl_1_6b_config.json | 33 +++ .../config/lfm2_5_vl_450m_config.json | 33 +++ examples/models/lfm2_5_vl/convert_weights.py | 81 ++++++ examples/models/lfm2_5_vl/export_lfm2_5_vl.py | 254 ++++++++++++++++++ examples/models/lfm2_5_vl/model.py | 141 ++++++++++ examples/models/llama/llama_transformer.py | 1 + exir/emit/_emitter.py | 5 +- 9 files changed, 626 insertions(+), 77 deletions(-) create mode 100644 examples/models/lfm2_5_vl/__init__.py create mode 100644 examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json create mode 100644 examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json create mode 100644 examples/models/lfm2_5_vl/convert_weights.py create mode 100644 examples/models/lfm2_5_vl/export_lfm2_5_vl.py create mode 100644 examples/models/lfm2_5_vl/model.py diff --git a/examples/models/lfm2/short_conv.py b/examples/models/lfm2/short_conv.py index ae04580d6c6..11ad83c81e4 100644 --- a/examples/models/lfm2/short_conv.py +++ b/examples/models/lfm2/short_conv.py @@ -1,112 +1,102 @@ -from typing import Optional +from __future__ import annotations import torch from executorch.examples.models.llama.attention import ForwardOptions from executorch.examples.models.llama.feed_forward import FeedForward - from executorch.examples.models.llama.norm import RMSNorm from torch import nn class ShortConv(nn.Module): - def __init__( - self, - dim: int, - L_cache: int = 3, - bias: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): + """Depthwise short convolution with dual state management. + + Supports two modes: + 1. State-as-IO: caller passes conv_state in and receives new state back. + Required for AOTI which cannot re-trace mutable buffer mutations. + 2. Internal buffer: uses register_buffer + copy_() for XNNPack/portable + backends where mutable buffers are handled natively. 
+ """ + + def __init__(self, dim: int, L_cache: int = 3, *, bias: bool = False) -> None: super().__init__() + assert L_cache == 3, f"Manual depthwise conv only supports L_cache=3, got {L_cache}" self.dim = dim self.L_cache = L_cache - self.device = device - self.dtype = dtype - self.bias = bias - - self.conv = nn.Conv1d( - dim, - dim, - kernel_size=L_cache, - padding=0, ## we don't need padding since we handle it manually - groups=dim, - bias=bias, - ) - - conv_state = torch.zeros( - 1, ## batch size is assumed to be 1 for now - dim, - L_cache - 1, - device="cpu", - ) - self.register_buffer("conv_state", conv_state) - ## better performance in Executorch with separate projections + self.conv = nn.Conv1d(dim, dim, kernel_size=L_cache, padding=0, groups=dim, bias=bias) self.B_proj = nn.Linear(dim, dim, bias=bias) self.C_proj = nn.Linear(dim, dim, bias=bias) self.x_proj = nn.Linear(dim, dim, bias=bias) - self.out_proj = nn.Linear(dim, dim, bias=bias) - def forward(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seqlen, dim = x.size() - assert batch_size == 1, "batch_size must be 1" - - B = self.B_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) - C = self.C_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) - x = self.x_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len) - - Bx = B * x # (batch_size, dim, seq_len) + self.register_buffer( + "conv_state", + torch.zeros(1, dim, L_cache - 1), + ) - ## This is where we handle padding - ## By default, the conv_state is initialized to 0. - # So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similary to - ## using nn.Conv1d(padding=L_cache-1) (for prefill) without no manual padding. - ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0. 
- Bx = torch.cat( - [self.conv_state, Bx], dim=-1 - ) # (batch_size, dim, seq_len + L_cache - 1) + def forward( + self, x: torch.Tensor, conv_state: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if conv_state is None: + conv_state = self.conv_state - ## Update the conv_state - new_conv_state = Bx[ - ..., -(self.L_cache - 1) : - ] # (batch_size, dim, L_cache - 1) - with torch.no_grad(): - self.conv_state.copy_(new_conv_state) + B = self.B_proj(x).transpose(-1, -2) + C = self.C_proj(x).transpose(-1, -2) + x = self.x_proj(x).transpose(-1, -2) - conv_out = self.conv(Bx)[..., : x.size(-1)] # (batch_size, dim, seq_len) - y = C * conv_out # (batch_size, dim, seq_len) + Bx = torch.cat([conv_state, B * x], dim=-1) + new_conv_state = Bx[..., -(self.L_cache - 1) :] - y = y.transpose(-1, -2) # (batch_size, seq_len, dim) - y = y.contiguous() # (batch_size, seq_len, dim) - y = self.out_proj(y) # (batch_size, seq_len, dim) - return y + # Manual depthwise conv — Triton has no template for nn.Conv1d + # with groups=dim and dynamic sequence length. 
+ w = self.conv.weight[:, 0, :] + conv_out = Bx[..., :-2] * w[:, 0:1] + Bx[..., 1:-1] * w[:, 1:2] + Bx[..., 2:] * w[:, 2:3] - def reset_cache(self): - self.conv_state.zero_() + y = self.out_proj((C * conv_out).transpose(-1, -2).contiguous()) + return y, new_conv_state class ShortConvBlock(nn.Module): - def __init__(self, dim: int, hidden_dim: int, norm_eps: float): + def __init__(self, dim: int, hidden_dim: int, norm_eps: float, layer_idx: int = -1) -> None: super().__init__() - self.L_cache = 3 # hardcode 3 for now - self.conv = ShortConv(dim, self.L_cache, bias=False) + self.layer_idx = layer_idx + self.conv = ShortConv(dim, L_cache=3, bias=False) self.feed_forward = FeedForward(dim, hidden_dim) self.ffn_norm = RMSNorm(dim, norm_eps) - # use attention_norm norm instead of operator_norm to unify with TransformerBlock self.attention_norm = RMSNorm(dim, norm_eps) def forward( self, - x, - freqs_cos=None, - freqs_sin=None, - _unused_attn_options: Optional[ForwardOptions] = None, - ): # x: 1xN - h = self.conv.forward(self.attention_norm(x)) + x: torch.Tensor, + freqs_cos: torch.Tensor | None = None, + freqs_sin: torch.Tensor | None = None, + attn_options: ForwardOptions | None = None, + ) -> tuple[torch.Tensor, dict]: + # State-as-IO: read from attn_options if provided (CUDA/AOTI path) + conv_state = None + if attn_options is not None: + conv_states = attn_options.get("conv_states") + if conv_states is not None: + conv_state = conv_states.get(self.layer_idx) + + h, new_conv_state = self.conv(self.attention_norm(x), conv_state) h = x + h out = h + self.feed_forward(self.ffn_norm(h)) - return out, None - def reset_cache(self): - self.conv.reset_cache() + # Write back state + update: dict = {} + if attn_options is not None and "conv_states" in attn_options: + if conv_state is not None: + conv_state.copy_(new_conv_state) + states = dict(attn_options["conv_states"]) + states[self.layer_idx] = new_conv_state + update["conv_states"] = states + else: + # XNNPack/portable 
path: persist via internal buffer + with torch.no_grad(): + self.conv.conv_state.copy_(new_conv_state) + + return out, update + + def reset_cache(self) -> None: + self.conv.conv_state.zero_() diff --git a/examples/models/lfm2_5_vl/__init__.py b/examples/models/lfm2_5_vl/__init__.py new file mode 100644 index 00000000000..f1fe2afba26 --- /dev/null +++ b/examples/models/lfm2_5_vl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.lfm2_5_vl.convert_weights import convert_weights +from executorch.examples.models.lfm2_5_vl.model import Lfm2p5VlModel + +__all__ = [ + "convert_weights", + "Lfm2p5VlModel", +] diff --git a/examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json b/examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json new file mode 100644 index 00000000000..396f7bb7a8a --- /dev/null +++ b/examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json @@ -0,0 +1,33 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 8192, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 16, + "norm_eps": 1e-5, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 65536, + "use_hf_rope": true, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "layer_types": [ + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv" + ] +} diff --git a/examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json b/examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json new file mode 100644 index 00000000000..975ccbccca7 --- /dev/null +++ b/examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json @@ -0,0 +1,33 @@ +{ + "dim": 1024, + "ffn_dim_multiplier": 1, + 
"hidden_dim": 4608, + "n_heads": 16, + "n_kv_heads": 8, + "n_layers": 16, + "norm_eps": 1e-5, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 65536, + "use_hf_rope": true, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "layer_types": [ + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv" + ] +} diff --git a/examples/models/lfm2_5_vl/convert_weights.py b/examples/models/lfm2_5_vl/convert_weights.py new file mode 100644 index 00000000000..82ccba110ee --- /dev/null +++ b/examples/models/lfm2_5_vl/convert_weights.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Convert LFM2.5-VL text decoder weights from HuggingFace to ET format.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import torch +from executorch.examples.models.checkpoint import get_mapped_key +from safetensors.torch import load_file + +_LFM2_5_VL_TO_META: dict[str, str] = { + "model.language_model.embed_tokens.weight": "tok_embeddings.weight", + "model.language_model.embedding_norm.weight": "norm.weight", + "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.language_model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight", + "model.language_model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.language_model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight", + 
"model.language_model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight", + "model.language_model.layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", + "model.language_model.layers.{}.feed_forward.w1.weight": "layers.{}.feed_forward.w1.weight", + "model.language_model.layers.{}.feed_forward.w2.weight": "layers.{}.feed_forward.w2.weight", + "model.language_model.layers.{}.feed_forward.w3.weight": "layers.{}.feed_forward.w3.weight", + "model.language_model.layers.{}.conv.conv.weight": "layers.{}.conv.conv.weight", + "model.language_model.layers.{}.conv.out_proj.weight": "layers.{}.conv.out_proj.weight", + "model.language_model.lm_head.weight": "output.weight", +} + +_IN_PROJ_SPLITS = ("B_proj", "C_proj", "x_proj") + + +def lfm2_5_vl_to_meta(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Extract and remap language model weights from a full VL state dict.""" + converted: dict[str, torch.Tensor] = {} + + for key, value in state_dict.items(): + if not key.startswith("model.language_model."): + continue + + try: + new_key = get_mapped_key(key, _LFM2_5_VL_TO_META) + except Exception: + new_key = key.removeprefix("model.language_model.") + + if new_key.endswith(".conv.in_proj.weight"): + for name, chunk in zip(_IN_PROJ_SPLITS, torch.chunk(value, 3, dim=0)): + converted[new_key.replace("in_proj", name)] = chunk + else: + converted[new_key] = value + + if "output.weight" not in converted: + converted["output.weight"] = converted["tok_embeddings.weight"] + + return converted + + +def convert_weights(input_dir: str, output_file: str) -> None: + sd = load_file(str(Path(input_dir) / "model.safetensors")) + sd = lfm2_5_vl_to_meta(sd) + torch.save(sd, output_file) + print(f"Saved {len(sd)} tensors to {output_file}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert LFM2.5-VL weights to ET format.") + parser.add_argument("input_dir", help="Directory containing model.safetensors.") + 
parser.add_argument("output", help="Output .pt checkpoint path.") + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/lfm2_5_vl/export_lfm2_5_vl.py b/examples/models/lfm2_5_vl/export_lfm2_5_vl.py new file mode 100644 index 00000000000..f27ddd586ec --- /dev/null +++ b/examples/models/lfm2_5_vl/export_lfm2_5_vl.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export LFM2.5-VL as a multi-method PTE for ExecuTorch with CUDA/AOTI backend. + +All three methods are delegated to the CUDA backend. Conv layer state is +threaded through attn_options as explicit IO; KV cache uses mark_static_address +so AOTI can trace through in-place mutations. + +Methods (D = text hidden dim): + vision_encoder : [1, 3, 512, 512] f32 -> [1, 256, D] f32 + token_embedding : [1, seq_len] i64 -> [1, seq_len, D] f32 + text_decoder : ([1, seq_len, D], [seq_len] i64) -> [1, vocab] f32 + +Usage: + python examples/models/lfm2_5_vl/export_lfm2_5_vl.py \\ + --model_dir LiquidAI/LFM2.5-VL-450M --dtype bf16 +""" + +from __future__ import annotations + +import logging +from argparse import ArgumentParser +from pathlib import Path +import torch +from torch.export import Dim, ExportedProgram +from torch.nn.attention import SDPBackend + +from executorch.backends.cuda.cuda_backend import CudaBackend +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass + +from executorch.examples.models.lfm2_5_vl.model import ( + IMAGE_SIZE, + MAX_SEQ_LEN, + Lfm2p5VlModel, +) + 
+logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s", +) + +# --------------------------------------------------------------------------- +# Blackwell (sm_103) workaround: torch._inductor maps arch 103 -> "100f" but +# Triton generates PTX targeting sm_103a. Patch to match. +# --------------------------------------------------------------------------- +try: + from torch._inductor.codecache import cuda_compile_utils + + _orig_nvcc_arch = cuda_compile_utils._nvcc_arch_as_compile_option + + def _patched_nvcc_arch() -> str: + arch = cuda_compile_utils.cuda_env.get_cuda_arch() + return "103a" if arch == "103" else _orig_nvcc_arch() + + cuda_compile_utils._nvcc_arch_as_compile_option = _patched_nvcc_arch +except (ImportError, AttributeError): + pass + +_CONFIG_DIR = Path(__file__).parent / "config" + +_DTYPE_MAP: dict[str, torch.dtype] = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def _resolve_params_path(model_dir: str, params: str | None) -> str | None: + if params is not None: + return params + name = model_dir.lower() + if "450m" in name: + return str(_CONFIG_DIR / "lfm2_5_vl_450m_config.json") + if "1.6b" in name or "1_6b" in name: + return str(_CONFIG_DIR / "lfm2_5_vl_1_6b_config.json") + return None + + +# --------------------------------------------------------------------------- +# Per-method export +# --------------------------------------------------------------------------- + + +def _export_image_encoder(lfm2: torch.nn.Module, *, device: str) -> ExportedProgram: + class _Encoder(torch.nn.Module): + def __init__(self, lfm2: torch.nn.Module) -> None: + super().__init__() + self.lfm2 = lfm2 + + def forward(self, images: torch.Tensor) -> torch.Tensor: + return self.lfm2.image_embedding(images) + + example = torch.randint(0, 256, (1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32, device=device) + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), 
torch.no_grad(): + return torch.export.export(_Encoder(lfm2), (example,), strict=False) + + +def _export_text_decoder(lfm2: torch.nn.Module, *, dtype: torch.dtype, device: str) -> ExportedProgram: + from executorch.examples.models.lfm2.short_conv import ShortConvBlock + + conv_indices = [i for i, layer in enumerate(lfm2.text_model.layers) if isinstance(layer, ShortConvBlock)] + dim = lfm2.text_model_args.dim + + class _Decoder(torch.nn.Module): + def __init__( + self, text_model: torch.nn.Module, conv_dim: int, conv_indices: list[int], + *, dtype: torch.dtype, device: str, + ) -> None: + super().__init__() + self.text_model = text_model + self.conv_indices = conv_indices + for idx in conv_indices: + buf = torch.zeros(1, conv_dim, 2, dtype=dtype, device=device) + self.register_buffer(f"conv_state_{idx}", buf, persistent=False) + if not torch.compiler.is_compiling(): + torch._dynamo.mark_static_address(buf) + + def forward(self, embeddings: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + conv_states = { + idx: getattr(self, f"conv_state_{idx}") + for idx in self.conv_indices + } + out = self.text_model(None, {"input_pos": input_pos, "conv_states": conv_states}, embeddings) + if isinstance(out, tuple): + out = out[0] + return out.contiguous() + + seq = 8 + token_dim = Dim("token_dim", min=1, max=MAX_SEQ_LEN - 1) + example_emb = torch.randn(1, seq, dim, dtype=dtype, device=device) + example_pos = torch.arange(seq, dtype=torch.int64, device=device) + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + return torch.export._trace._export( + _Decoder(lfm2.text_model, dim, conv_indices, dtype=dtype, device=device), + (example_emb, example_pos), + dynamic_shapes=({1: token_dim}, {0: token_dim}), + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + +def _export_token_embedding(lfm2: torch.nn.Module, *, device: str) -> ExportedProgram: + embed = lfm2.model_.model.language_model.get_input_embeddings() + token_dim = 
Dim("token_dim_1", min=1, max=MAX_SEQ_LEN) + example = torch.zeros(1, MAX_SEQ_LEN, dtype=torch.int64, device=device) + with torch.no_grad(): + return torch.export.export(embed, (example,), dynamic_shapes=[{1: token_dim}], strict=False) + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def export_all( + model_dir: str, + output: str, + *, + dtype: torch.dtype = torch.bfloat16, + max_seq_len: int = MAX_SEQ_LEN, + params_path: str | None = None, +) -> None: + logging.info("Loading %s...", model_dir) + lfm2_model = Lfm2p5VlModel( + model_dir=model_dir, + max_seq_len=max_seq_len, + max_context_len=max_seq_len, + params_path=params_path, + use_sdpa_with_kv_cache_op=False, + ) + lfm2 = lfm2_model.get_eager_model().to(dtype=dtype, device="cuda") + + # Mark KV cache buffers as static addresses after device migration, + # so AOTI can trace through in-place index_put mutations. 
+ for module in lfm2.text_model.modules(): + for name, buf in module.named_buffers(recurse=False): + if name in ("k_cache", "v_cache"): + torch._dynamo.mark_static_address(buf) + + logging.info("[1/3] Vision encoder") + vision_ep = _export_image_encoder(lfm2, device="cuda") + logging.info("[2/3] Text decoder") + decoder_ep = _export_text_decoder(lfm2, dtype=dtype, device="cuda") + logging.info("[3/3] Token embedding") + token_ep = _export_token_embedding(lfm2, device="cuda") + + programs = {"vision_encoder": vision_ep, "token_embedding": token_ep, "text_decoder": decoder_ep} + partitioners = { + k: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(k)])] + for k in programs + } + metadata = { + "get_max_seq_len": lfm2.text_model_args.max_seq_len, + "get_vocab_size": lfm2.text_model_args.vocab_size, + "use_kv_cache": lfm2.text_model_args.use_kv_cache, + "get_eos_ids": [7], + } + + logging.info("Lowering to Edge IR + CUDA") + et_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioners, + compile_config=EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True), + constant_methods=metadata, + ) + + logging.info("Finalizing ExecuTorch program") + et_program = et_prog.to_executorch( + ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + sym_shape_eval_pass={k: ConstraintBasedSymShapeEvalPass() for k in programs}, + ) + ) + + output_path = Path(output) + output_dir = output_path.parent + logging.info("Saving %s", output_path) + with open(output_path, "wb") as f: + et_program.write_to_file(f) + et_program.write_tensor_data_to_file(str(output_dir)) + logging.info("Done — methods: %s", et_program.methods) + + +def main() -> None: + parser = ArgumentParser(description="Export LFM2.5-VL to ExecuTorch (CUDA)") + parser.add_argument("--model_dir", default="LiquidAI/LFM2.5-VL-450M") + parser.add_argument("--dtype", default="bf16", choices=list(_DTYPE_MAP)) + parser.add_argument("--max_seq_len", 
type=int, default=MAX_SEQ_LEN) + parser.add_argument("--params", default=None) + parser.add_argument("--output", default=None) + args = parser.parse_args() + + dtype = _DTYPE_MAP[args.dtype] + params_path = _resolve_params_path(args.model_dir, args.params) + output = args.output or f"lfm2_5_vl_{args.dtype}_cuda.pte" + + export_all(args.model_dir, output, dtype=dtype, max_seq_len=args.max_seq_len, params_path=params_path) + + +if __name__ == "__main__": + main() diff --git a/examples/models/lfm2_5_vl/model.py b/examples/models/lfm2_5_vl/model.py new file mode 100644 index 00000000000..a952f3f7062 --- /dev/null +++ b/examples/models/lfm2_5_vl/model.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""ExecuTorch-friendly LFM2.5-VL model. Mirrors examples/models/llava/model.py.""" + +from __future__ import annotations + +import json +import math +from pathlib import Path + +import torch +import torch.nn.functional as F +from executorch.examples.models.lfm2_5_vl.convert_weights import lfm2_5_vl_to_meta +from executorch.examples.models.llama.llama_transformer import construct_transformer +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) +from executorch.examples.models.llama.source_transformation.sdpa import ( + replace_sdpa_with_custom_op, +) +from executorch.examples.models.model_base import EagerModelBase +from torch.export import Dim +from transformers import AutoModelForImageTextToText, AutoProcessor + +MAX_SEQ_LEN = 2048 +IMAGE_SIZE = 512 +PATCH_SIZE = 16 +FIXED_H, FIXED_W = 32, 32 + +_DEFAULT_PARAMS = Path(__file__).parent / "config" / "lfm2_5_vl_1_6b_config.json" + + +class Lfm2p5Vl(torch.nn.Module): + def __init__(self, hf_model: 
AutoModelForImageTextToText, params: ModelArgs) -> None: + super().__init__() + self.model_ = hf_model + self.text_model_args = params + self.text_model = construct_transformer(params) + + if params.use_sdpa_with_kv_cache_op: + self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model) + self.text_model = replace_sdpa_with_custom_op(self.text_model) + + self.text_model.load_state_dict( + state_dict=self._translate_weights(), strict=False, assign=True + ) + self._patch_positional_embeddings() + + def _patch_positional_embeddings(self) -> None: + embeddings = self.model_.model.vision_tower.vision_model.embeddings + orig = embeddings.position_embedding.weight.data + sqrt_n = int(math.sqrt(orig.shape[0])) + + grid = orig.reshape(sqrt_n, sqrt_n, -1).permute(2, 0, 1).unsqueeze(0) + resized = F.interpolate( + grid, size=(FIXED_H, FIXED_W), mode="bilinear", align_corners=False, antialias=True + ) + pe = resized.squeeze(0).permute(1, 2, 0).reshape(FIXED_H * FIXED_W, -1).contiguous() + embeddings.register_buffer("_precomputed_pe", pe, persistent=False) + embeddings.resize_positional_embeddings = lambda *_args, **_kw: embeddings._precomputed_pe + + def _translate_weights(self) -> dict[str, torch.Tensor]: + raw: dict[str, torch.Tensor] = {} + for k, v in self.model_.model.language_model.state_dict().items(): + raw[f"model.language_model.{k}"] = v + for k, v in self.model_.lm_head.state_dict().items(): + raw[f"model.language_model.lm_head.{k}"] = v + return lfm2_5_vl_to_meta(raw) + + def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + return self.model_.model.language_model.get_input_embeddings()(tokens) + + def image_embedding(self, nchw_pixels: torch.Tensor) -> torch.Tensor: + """[B, 3, 512, 512] float32 pixels in [0, 255] -> [B, 256, D].""" + x = (nchw_pixels / 255.0 - 0.5) / 0.5 + + x = x.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE) + x = x.permute(0, 2, 3, 4, 5, 1).reshape(1, FIXED_H * FIXED_W, PATCH_SIZE * PATCH_SIZE * 3) + 
+ out = self.model_.model.vision_tower( + pixel_values=x, + pixel_attention_mask=None, + spatial_shapes=torch.tensor([[FIXED_H, FIXED_W]], dtype=torch.int64, device=x.device), + return_dict=True, + ) + feats = out.last_hidden_state.reshape(-1, FIXED_H, FIXED_W, out.last_hidden_state.shape[-1]) + projected = self.model_.model.multi_modal_projector(feats) + return projected.reshape(1, -1, projected.shape[-1]) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + return self.image_embedding(images) + + +class Lfm2p5VlModel(EagerModelBase): + def __init__( + self, + *, + use_sdpa_with_kv_cache_op: bool = True, + use_kv_cache: bool = True, + max_seq_len: int = MAX_SEQ_LEN, + max_context_len: int = MAX_SEQ_LEN, + model_dir: str = "LiquidAI/LFM2.5-VL-1.6B", + params_path: str | None = None, + ) -> None: + self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op + self.max_context_len = max_context_len + self.max_seq_len = max_seq_len + self.model_dir = model_dir + + resolved = Path(params_path) if params_path else _DEFAULT_PARAMS + params = json.loads(resolved.read_text()) + + self.text_model_args = ModelArgs( + max_batch_size=1, + max_seq_len=max_seq_len, + max_context_len=max_context_len, + use_kv_cache=use_kv_cache, + use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op, + enable_dynamic_shape=False, + **params, + ) + + self.hf_model = AutoModelForImageTextToText.from_pretrained( + model_dir, device_map="cpu", torch_dtype=torch.float32 + ) + self.processor = AutoProcessor.from_pretrained(model_dir) + self.tokenizer = self.processor.tokenizer + + def get_eager_model(self) -> torch.nn.Module: + return Lfm2p5Vl(self.hf_model, self.text_model_args).to(dtype=torch.float32) + + def get_example_inputs(self) -> tuple[torch.Tensor, ...]: + return (torch.randint(0, 256, (1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32),) + + def get_dynamic_shapes(self) -> None: + return None diff --git a/examples/models/llama/llama_transformer.py 
b/examples/models/llama/llama_transformer.py index cb87995aaf6..3c38039f412 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -351,6 +351,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: dim=model_args.dim, hidden_dim=model_args.hidden_dim, norm_eps=model_args.norm_eps, + layer_idx=layer_id, ) ) elif ( diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index a48d88fa224..440c6dd9b4c 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -456,10 +456,13 @@ def _tensor_spec_to_evalue( ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes() ) + storage = typing.cast(torch.UntypedStorage, spec.storage) + if spec.allocated_memory != 0 and storage.device.type != "cpu": + storage = storage.cpu() buffer_data = ( bytes( ctypes.cast( - typing.cast(torch.UntypedStorage, spec.storage).data_ptr(), + storage.data_ptr(), ctypes.POINTER(spec_array_type), ).contents )