# Language-model half of the Qwen3-VL-4B HF config. The values mirror the
# MaxText model config `qwen3-vl-4b.yml` (2560 hidden, 36 layers, 32 query /
# 8 KV heads, head_dim 128, rope_theta 5e6, mrope sections [24, 20, 20]).
_qwen3_vl_4b_text_config = {
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "dtype": "bfloat16",
    "eos_token_id": 151645,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 2560,
    "initializer_range": 0.02,
    "intermediate_size": 9728,
    "max_position_embeddings": 262144,
    "model_type": "qwen3_vl_text",
    "num_attention_heads": 32,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "pad_token_id": None,
    "rms_norm_eps": 1e-06,
    "rope_parameters": {
        "mrope_interleaved": True,
        "mrope_section": [24, 20, 20],
        "rope_theta": 5000000,
        "rope_type": "default",
    },
    "tie_word_embeddings": True,
    "use_cache": True,
    "vocab_size": 151936,
}

# Vision-tower half of the config. `out_hidden_size` (2560) equals the text
# `hidden_size` above, and three deepstack taps are listed (blocks 5, 11, 17).
_qwen3_vl_4b_vision_config = {
    "deepstack_visual_indexes": [5, 11, 17],
    "depth": 24,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1024,
    "in_channels": 3,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "model_type": "qwen3_vl_vision",
    "num_heads": 16,
    "num_position_embeddings": 2304,
    "out_hidden_size": 2560,
    "patch_size": 16,
    "spatial_merge_size": 2,
    "temporal_patch_size": 2,
}

# Top-level HF config for Qwen/Qwen3-VL-4B-Instruct; key order is kept the same
# as a single nested literal would produce.
qwen3_vl_4b_dict = {
    "architectures": ["Qwen3VLForConditionalGeneration"],
    "image_token_id": 151655,
    "model_type": "qwen3_vl",
    "text_config": _qwen3_vl_4b_text_config,
    "tie_word_embeddings": True,
    "transformers_version": "5.8.0",
    "video_token_id": 151656,
    "vision_config": _qwen3_vl_4b_vision_config,
    "vision_end_token_id": 151653,
    "vision_start_token_id": 151652,
}
qwen3_vl_4b_config = PTConfig(**qwen3_vl_4b_dict)
def QWEN3_VL_HF_WEIGHTS_TO_SHAPE(config):
  """Returns mapping between HuggingFace Qwen3-VL weights path and weights shape.

  Text-decoder shapes are produced by `QWEN_HF_WEIGHTS_TO_SHAPE` from
  `config["text_config"]`; every key with a leading `model.` is re-rooted under
  `model.language_model.` and all other keys are copied unchanged. Vision-tower
  shapes (`model.visual.*`) are then derived from `config["vision_config"]`.

  Args:
    config: Dict-like HF config exposing `text_config` and `vision_config`.
      `vision_config` must provide `depth`, `hidden_size`, `intermediate_size`,
      `out_hidden_size` and `num_position_embeddings`. `in_channels`,
      `temporal_patch_size`, `patch_size`, `spatial_merge_size` and
      `deepstack_visual_indexes` are optional and fall back to 3, 2, 16, 2 and
      `[5, 11, 17]` respectively.

  Returns:
    Dict mapping each HF parameter name to its shape (a list of ints).
  """
  text_prefix = "model."
  vl_text_prefix = "model.language_model."

  vl_shapes = {}
  for name, shape in QWEN_HF_WEIGHTS_TO_SHAPE(config["text_config"]).items():
    # Rewrite only the leading prefix; an unanchored str.replace would also
    # rewrite any later "model." substring inside the parameter name.
    if name.startswith(text_prefix):
      name = vl_text_prefix + name[len(text_prefix) :]
    vl_shapes[name] = shape

  vision_config = config["vision_config"]
  v_depth = vision_config["depth"]
  v_hidden_size = vision_config["hidden_size"]
  v_intermediate_size = vision_config["intermediate_size"]
  v_out_hidden_size = vision_config["out_hidden_size"]
  v_in_channels = vision_config.get("in_channels", 3)
  v_temporal_patch_size = vision_config.get("temporal_patch_size", 2)
  v_patch_size = vision_config.get("patch_size", 16)
  # Width of the patch-merger MLPs. In the HF Qwen3-VL patch merger this is
  # hidden_size * spatial_merge_size**2, not the block MLP's intermediate_size.
  # The two coincide for qwen3-vl-4b (1024 * 2**2 == 4096), so the shapes for
  # that model are unchanged, but they differ for other vision towers.
  # NOTE(review): confirm against the HF Qwen3VLVisionPatchMerger definition.
  v_merged_size = v_hidden_size * vision_config.get("spatial_merge_size", 2) ** 2

  # 3D-conv patch embedding: [out_channels, in_channels, T, H, W].
  vl_shapes["model.visual.patch_embed.proj.weight"] = [
      v_hidden_size,
      v_in_channels,
      v_temporal_patch_size,
      v_patch_size,
      v_patch_size,
  ]
  vl_shapes["model.visual.patch_embed.proj.bias"] = [v_hidden_size]
  vl_shapes["model.visual.pos_embed.weight"] = [vision_config["num_position_embeddings"], v_hidden_size]

  for i in range(v_depth):
    prefix = f"model.visual.blocks.{i}"
    vl_shapes[f"{prefix}.norm1.weight"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm1.bias"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm2.weight"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm2.bias"] = [v_hidden_size]
    # Q, K and V are fused into a single projection on the HF side.
    vl_shapes[f"{prefix}.attn.qkv.weight"] = [3 * v_hidden_size, v_hidden_size]
    vl_shapes[f"{prefix}.attn.qkv.bias"] = [3 * v_hidden_size]
    vl_shapes[f"{prefix}.attn.proj.weight"] = [v_hidden_size, v_hidden_size]
    vl_shapes[f"{prefix}.attn.proj.bias"] = [v_hidden_size]
    vl_shapes[f"{prefix}.mlp.linear_fc1.weight"] = [v_intermediate_size, v_hidden_size]
    vl_shapes[f"{prefix}.mlp.linear_fc1.bias"] = [v_intermediate_size]
    vl_shapes[f"{prefix}.mlp.linear_fc2.weight"] = [v_hidden_size, v_intermediate_size]
    vl_shapes[f"{prefix}.mlp.linear_fc2.bias"] = [v_hidden_size]

  # One deepstack merger per tapped block; only the count of indexes matters
  # here because HF stores the mergers in a list indexed 0..len-1.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"model.visual.deepstack_merger_list.{merger_idx}"
    # Deepstack mergers normalise the already-merged features.
    vl_shapes[f"{prefix}.norm.weight"] = [v_merged_size]
    vl_shapes[f"{prefix}.norm.bias"] = [v_merged_size]
    vl_shapes[f"{prefix}.linear_fc1.weight"] = [v_merged_size, v_merged_size]
    vl_shapes[f"{prefix}.linear_fc1.bias"] = [v_merged_size]
    vl_shapes[f"{prefix}.linear_fc2.weight"] = [v_out_hidden_size, v_merged_size]
    vl_shapes[f"{prefix}.linear_fc2.bias"] = [v_out_hidden_size]

  # Final merger: its norm is applied per patch (hidden_size wide), while its
  # MLP consumes the merged features.
  vl_shapes["model.visual.merger.norm.weight"] = [v_hidden_size]
  vl_shapes["model.visual.merger.norm.bias"] = [v_hidden_size]
  vl_shapes["model.visual.merger.linear_fc1.weight"] = [v_merged_size, v_merged_size]
  vl_shapes["model.visual.merger.linear_fc1.bias"] = [v_merged_size]
  vl_shapes["model.visual.merger.linear_fc2.weight"] = [v_out_hidden_size, v_merged_size]
  vl_shapes["model.visual.merger.linear_fc2.bias"] = [v_out_hidden_size]

  return vl_shapes
def QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen3-VL weight paths.

  The text-decoder entries are obtained from `QWEN_MAXTEXT_TO_HF_PARAM_MAPPING`
  (called with only `num_hidden_layers` taken from `config["text_config"]`) and
  every HF path that starts with `model.` is re-rooted under
  `model.language_model.`. Vision-tower entries are then added for the patch
  embedding, position embedding, each transformer block, each deepstack merger
  and the final merger.

  Args:
    config: Dict-like HF config exposing `text_config` and `vision_config`.
    maxtext_config: MaxText config, forwarded unchanged to the text mapping.
    scan_layers: Forwarded unchanged to the text mapping.

  Returns:
    Dict whose keys are MaxText parameter paths (a tuple of paths for the fused
    vision QKV entries) and whose values are HF parameter paths (or lists of
    them, as produced by the text mapping).
  """
  mapping = {}

  n_layers_text = config["text_config"]["num_hidden_layers"]
  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
      config={"num_hidden_layers": n_layers_text},
      maxtext_config=maxtext_config,
      scan_layers=scan_layers,
  )

  text_prefix = "model."
  vl_text_prefix = "model.language_model."

  def replace_prefix(val):
    """Re-roots HF text paths (possibly nested in lists) under the VL prefix."""
    if isinstance(val, list):
      return [replace_prefix(v) for v in val]
    # Anchor on the leading prefix only, so paths such as "lm_head.weight"
    # and any later "model." substring are left untouched.
    if isinstance(val, str) and val.startswith(text_prefix):
      return vl_text_prefix + val[len(text_prefix) :]
    return val

  for key, value in text_mapping.items():
    mapping[key] = replace_prefix(value)

  vision_config = config["vision_config"]
  n_vision_layers = vision_config["depth"]

  mapping["params-vision_encoder-Qwen3VLVisionEncoder_0-patch_embed-proj-kernel"] = "model.visual.patch_embed.proj.weight"
  mapping["params-vision_encoder-Qwen3VLVisionEncoder_0-patch_embed-proj-bias"] = "model.visual.patch_embed.proj.bias"

  mapping["params-vision_encoder-Qwen3VLVisionEncoder_0-pos_embed_interpolate-pos_embed"] = (
      "model.visual.pos_embed.weight"
  )

  for i in range(n_vision_layers):
    prefix = f"params-vision_encoder-Qwen3VLVisionEncoder_0-blocks_{i}"
    hf_prefix = f"model.visual.blocks.{i}"

    mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
    mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
    mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
    mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"

    # MaxText keeps separate query/key/value params while HF fuses them into
    # one `qkv` tensor, hence the composite (tuple) key.
    mapping[
        (
            f"{prefix}-attn-attn-query-kernel",
            f"{prefix}-attn-attn-key-kernel",
            f"{prefix}-attn-attn-value-kernel",
        )
    ] = f"{hf_prefix}.attn.qkv.weight"
    mapping[
        (
            f"{prefix}-attn-attn-query-bias",
            f"{prefix}-attn-attn-key-bias",
            f"{prefix}-attn-attn-value-bias",
        )
    ] = f"{hf_prefix}.attn.qkv.bias"
    mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
    mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"

    mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
    mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
    mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
    mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"

  # Only the number of deepstack taps is used: mergers are indexed 0..len-1 on
  # both sides regardless of which block each one is attached to.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"params-vision_encoder-Qwen3VLVisionEncoder_0-merger_{merger_idx}"
    hf_prefix = f"model.visual.deepstack_merger_list.{merger_idx}"

    mapping[f"{prefix}-ln_q-scale"] = f"{hf_prefix}.norm.weight"
    mapping[f"{prefix}-ln_q-bias"] = f"{hf_prefix}.norm.bias"
    mapping[f"{prefix}-mlp_0-kernel"] = f"{hf_prefix}.linear_fc1.weight"
    mapping[f"{prefix}-mlp_0-bias"] = f"{hf_prefix}.linear_fc1.bias"
    mapping[f"{prefix}-mlp_2-kernel"] = f"{hf_prefix}.linear_fc2.weight"
    mapping[f"{prefix}-mlp_2-bias"] = f"{hf_prefix}.linear_fc2.bias"

  # The final merger lives under the projector module on the MaxText side.
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-ln_q-scale"] = "model.visual.merger.norm.weight"
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-ln_q-bias"] = "model.visual.merger.norm.bias"
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_0-kernel"] = "model.visual.merger.linear_fc1.weight"
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_0-bias"] = "model.visual.merger.linear_fc1.bias"
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_2-kernel"] = "model.visual.merger.linear_fc2.weight"
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_2-bias"] = "model.visual.merger.linear_fc2.bias"

  return mapping


def QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Qwen3-VL.

  Starts from the text-decoder hooks returned by
  `QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN` (keyed by MaxText path, so no re-rooting is
  needed) and adds hooks for the vision-tower kernels. Vision parameters that
  get no hook here (layer-norm scales/biases, linear biases, position
  embedding, patch-embed bias) are not transformed by this function.

  Args:
    config: Dict-like HF config exposing `text_config` and `vision_config`.
    maxtext_config: MaxText config, forwarded unchanged to the text hooks.
    scan_layers: Forwarded unchanged to the text hooks.
    saving_to_hf: Direction of every hook. True converts MaxText -> HF; False
      converts HF -> MaxText.

  Returns:
    Dict mapping a MaxText parameter path (or a tuple of paths for the fused
    vision QKV entries) to a callable `hook(input_tensor, target_shape)`.
  """
  mapping = {}

  n_layers_text = config["text_config"]["num_hidden_layers"]
  text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
      config={"num_hidden_layers": n_layers_text},
      maxtext_config=maxtext_config,
      scan_layers=scan_layers,
      saving_to_hf=saving_to_hf,
  )
  mapping.update(text_hooks)

  vision_config = config["vision_config"]
  n_vision_layers = vision_config["depth"]
  hidden_size = vision_config["hidden_size"]

  def reshape_kernel_vision(input_tensor, target_shape):
    """Reshape kernel for vision layers.

    Saving: reshape to the reversed target shape, then transpose. Loading:
    transpose, then reshape to the target shape.
    """
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    else:
      return input_tensor.T.reshape(target_shape)

  def reshape_conv3d_patch_embed(input_tensor, target_shape):
    """Reshape 3D conv patch embedding weight.

    Pure axis permutation; `target_shape` is unused. Saving moves the last two
    axes to the front in reversed order; loading applies the inverse
    permutation.
    """
    # presumably MaxText stores (T, H, W, in, out) and HF (out, in, T, H, W)
    # — TODO confirm against the encoder's conv definition.
    if saving_to_hf:
      return input_tensor.transpose(4, 3, 0, 1, 2)
    else:
      return input_tensor.transpose(2, 3, 4, 1, 0)

  def process_qkv_vision(input_tensor, target_shape=None):
    """Handles composite_mt_key: maxtext (query, key, value) <-> hf (qkv).

    Saving: `input_tensor` is a (q, k, v) triple; each is flattened to
    (hidden_size, hidden_size), transposed, and the three are concatenated on
    axis 0. Loading: the fused tensor is split into three hidden_size-row
    slices, each transposed and reshaped to its entry in `target_shape`, then
    stacked on a new last axis.
    """
    if saving_to_hf:
      q, k, v = input_tensor
      q_hf = q.reshape(hidden_size, hidden_size).T
      k_hf = k.reshape(hidden_size, hidden_size).T
      v_hf = v.reshape(hidden_size, hidden_size).T
      return np.concatenate([q_hf, k_hf, v_hf], axis=0)
    else:
      q_hf = input_tensor[:hidden_size, :]
      k_hf = input_tensor[hidden_size : 2 * hidden_size, :]
      v_hf = input_tensor[2 * hidden_size :, :]
      # target_shape holds one shape per component of the composite key.
      q_mt = q_hf.T.reshape(target_shape[0])
      k_mt = k_hf.T.reshape(target_shape[1])
      v_mt = v_hf.T.reshape(target_shape[2])
      return np.stack([q_mt, k_mt, v_mt], axis=-1)

  def process_qkv_bias_vision(input_tensor, target_shape=None):
    """Handles composite_mt_key: maxtext (query_bias, key_bias, value_bias) <-> hf (qkv_bias).

    Same split/concatenate scheme as the fused kernel, without transposes.
    """
    if saving_to_hf:
      qb, kb, vb = input_tensor
      qb_hf = qb.reshape(hidden_size)
      kb_hf = kb.reshape(hidden_size)
      vb_hf = vb.reshape(hidden_size)
      return np.concatenate([qb_hf, kb_hf, vb_hf], axis=0)
    else:
      qb_hf = input_tensor[:hidden_size]
      kb_hf = input_tensor[hidden_size : 2 * hidden_size]
      vb_hf = input_tensor[2 * hidden_size :]
      qb_mt = qb_hf.reshape(target_shape[0])
      kb_mt = kb_hf.reshape(target_shape[1])
      vb_mt = vb_hf.reshape(target_shape[2])
      return np.stack([qb_mt, kb_mt, vb_mt], axis=-1)

  def reshape_vision_attn_out(input_tensor, target_shape):
    """Reshape vision attention output projection.

    Saving flattens to (hidden_size, hidden_size) and transposes; loading
    transposes and reshapes to `target_shape`.
    """
    if saving_to_hf:
      return input_tensor.reshape(hidden_size, hidden_size).T
    else:
      return input_tensor.T.reshape(target_shape)

  mapping["params-vision_encoder-Qwen3VLVisionEncoder_0-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed

  for i in range(n_vision_layers):
    prefix = f"params-vision_encoder-Qwen3VLVisionEncoder_0-blocks_{i}"

    mapping[
        (
            f"{prefix}-attn-attn-query-kernel",
            f"{prefix}-attn-attn-key-kernel",
            f"{prefix}-attn-attn-value-kernel",
        )
    ] = process_qkv_vision
    mapping[
        (
            f"{prefix}-attn-attn-query-bias",
            f"{prefix}-attn-attn-key-bias",
            f"{prefix}-attn-attn-value-bias",
        )
    ] = process_qkv_bias_vision

    mapping[f"{prefix}-attn-attn-out-kernel"] = reshape_vision_attn_out

    mapping[f"{prefix}-mlp-kernel"] = reshape_kernel_vision
    mapping[f"{prefix}-mlp_out-kernel"] = reshape_kernel_vision

  # One pair of merger kernels per deepstack tap; only the count is used.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"params-vision_encoder-Qwen3VLVisionEncoder_0-merger_{merger_idx}"
    mapping[f"{prefix}-mlp_0-kernel"] = reshape_kernel_vision
    mapping[f"{prefix}-mlp_2-kernel"] = reshape_kernel_vision

  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_0-kernel"] = reshape_kernel_vision
  mapping["params-vision_encoder-Qwen3VLVisionProjector_0-merger-mlp_2-kernel"] = reshape_kernel_vision

  return mapping
config.to_json_string() diff --git a/src/maxtext/configs/models/qwen3-vl-4b.yml b/src/maxtext/configs/models/qwen3-vl-4b.yml new file mode 100644 index 0000000000..ed67cc3e47 --- /dev/null +++ b/src/maxtext/configs/models/qwen3-vl-4b.yml @@ -0,0 +1,56 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for Qwen/Qwen3-VL-4B-Instruct + +# Core Architectural Parameters +decoder_block: "qwen3" +base_emb_dim: 2560 +base_mlp_dim: 9728 +base_num_query_heads: 32 +base_num_kv_heads: 8 +base_num_decoder_layers: 36 +head_dim: 128 +mlp_activations: ["silu", "linear"] +vocab_size: 151936 +normalization_layer_epsilon: 1.0e-6 +use_qk_norm: true +logits_via_embedding: true +normalize_embedding_logits: false + +# RoPE Settings +rope_max_timescale: 5000000 + +# General Model Settings +enable_dropout: false + +# Vision Encoder Configuration +# Based on HuggingFace AutoConfig for Qwen/Qwen3-VL-4B-Instruct +use_multimodal: true +image_size_for_vit: 768 +hidden_size_for_vit: 1024 +intermediate_size_for_vit: 4096 +num_attention_heads_for_vit: 16 +num_hidden_layers_for_vit: 24 +num_channels_for_vit: 3 +patch_size_for_vit: 16 +temporal_patch_size_for_vit: 2 +spatial_merge_size_for_vit: 2 +out_hidden_size_for_vit: 2560 +num_position_embeddings_for_vit: 2304 +deepstack_visual_indexes_for_vit: [5, 11, 17] + +# MRoPE Settings +use_mrope: true +mrope_section: [24, 20, 20] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 
08a7d3e930..c8db90b33a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -261,6 +261,7 @@ class ProfilerType(str, Enum): "qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-480b-a35b", + "qwen3-vl-4b", "qwen3-next-80b-a3b", "qwen3-omni-30b-a3b", "qwen3-custom-30b-a3b", @@ -3163,6 +3164,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b", ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index ab7673d1d4..b08a059938 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -788,7 +788,7 @@ def init_rotary_embedding(self): rope_type = self.rope_type rope_use_scale = self.config.rope_use_scale if self.is_vision: - if self.config.model_name.startswith("qwen3-omni") or self.config.model_name.startswith("qwen3.5"): + if self.config.model_name.startswith("qwen3"): rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding( hidden_size=self.config.hidden_size_for_vit, num_attention_heads=self.config.num_attention_heads_for_vit, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 6bca4b98c7..2ff3ff3315 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -728,6 +728,7 @@ def _apply_embedding( "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b", ]: @@ -742,7 +743,7 @@ def _apply_embedding( raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") if video_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: y = mm_utils.merge_mm_embeddings( text_embeddings=y, multimodal_embeddings=video_embeddings, diff --git 
a/src/maxtext/layers/encoders.py b/src/maxtext/layers/encoders.py index 0db94ad9d6..93d44452d4 100644 --- a/src/maxtext/layers/encoders.py +++ b/src/maxtext/layers/encoders.py @@ -80,6 +80,16 @@ def _setup_vision_encoder_layers(self): ) setattr(self, projector_name, qwen3_5_vision.Qwen3_5MoeVisionProjector(config=self.config, rngs=self.rngs)) return encoder_name, projector_name + elif self.config.model_name in ["qwen3-vl-4b"]: + from maxtext.models import qwen3_vl_vision # pylint: disable=import-outside-toplevel + + encoder_name = "Qwen3VLVisionEncoder_0" + projector_name = "Qwen3VLVisionProjector_0" + setattr( + self, encoder_name, qwen3_vl_vision.Qwen3VLVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs) + ) + setattr(self, projector_name, qwen3_vl_vision.Qwen3VLVisionProjector(config=self.config, rngs=self.rngs)) + return encoder_name, projector_name else: raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet") diff --git a/src/maxtext/models/qwen3_vl_vision.py b/src/maxtext/models/qwen3_vl_vision.py new file mode 100644 index 0000000000..a9afef1fbb --- /dev/null +++ b/src/maxtext/models/qwen3_vl_vision.py @@ -0,0 +1,32 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

"""Qwen3-VL vision tower: NNX subclasses of the Qwen3-Omni vision modules."""

from maxtext.models.qwen3 import Qwen3OmniMoeVisionEncoder, Qwen3OmniMoeVisionProjector


class Qwen3VLVisionEncoder(Qwen3OmniMoeVisionEncoder):
  """Vision encoder for Qwen3-VL models.

  Declares no attributes or methods of its own; all behaviour comes from
  `Qwen3OmniMoeVisionEncoder`. The distinct class gives Qwen3-VL its own
  encoder type, which the vision-encoder setup registers under the module name
  `Qwen3VLVisionEncoder_0` (the name the checkpoint param mapping keys on).
  """


class Qwen3VLVisionProjector(Qwen3OmniMoeVisionProjector):
  """Vision projector (final merger) for Qwen3-VL models.

  Declares no attributes or methods of its own; all behaviour comes from
  `Qwen3OmniMoeVisionProjector`. Registered by the vision-encoder setup under
  the module name `Qwen3VLVisionProjector_0`.
  """
preprocess_mm_data_qwen3_omni_for_training(image, config) @@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel return get_image_offsets_llama4(processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel return get_mm_offsets_qwen3_omni(config, processor_output) @@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel return reformat_prompt_llama4(prompt, image_placeholder, num_images) - elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel return reformat_prompt_qwen3_omni( @@ -137,7 +137,7 @@ def reformat_response(response, model_name): elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]: formatted_response = f"{response}" return formatted_response - elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: formatted_response = f"{response}<|im_end|>" return formatted_response else: @@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None): from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: 
disable=import-outside-toplevel return add_extra_tokens_for_images_llama4(tokens, processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output) @@ -181,7 +181,7 @@ def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_seque from maxtext.multimodal.processor_llama4 import get_dummy_image_shape_for_init_llama4 # pylint: disable=import-outside-toplevel image_shape = get_dummy_image_shape_for_init_llama4(batch_size, num_image_per_sequence) - elif model_name.startswith("qwen3-omni-30b-a3b") or model_name.startswith("qwen3.5"): + elif model_name.startswith("qwen3-omni") or model_name.startswith("qwen3-vl") or model_name.startswith("qwen3.5"): from maxtext.multimodal.processor_qwen3_omni import get_dummy_image_shape_for_init_qwen3_omni # pylint: disable=import-outside-toplevel image_shape = get_dummy_image_shape_for_init_qwen3_omni(batch_size) @@ -222,7 +222,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool = from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel tokens = QwenTokens(config) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index 57102f6a47..4ad0f1d9c4 100644 --- a/src/maxtext/utils/globals.py +++ 
b/src/maxtext/utils/globals.py @@ -63,6 +63,7 @@ "qwen3-8b": "Qwen/Qwen3-8B", "qwen3-14b": "Qwen/Qwen3-14B", "qwen3-32b": "Qwen/Qwen3-32B", + "qwen3-vl-4b": "Qwen/Qwen3-VL-4B-Instruct", "llama3.1-8b": "meta-llama/Llama-3.1-8B", "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct", "llama3.1-70b-Instruct": "meta-llama/Llama-3.1-70B-Instruct",