Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,63 @@ def __init__(self, **kwargs):
olmo3_32b_config = transformers.Olmo3Config(**olmo3_32b_dict)


# Hugging Face config values for Qwen3-VL 4B, registered below under the
# "qwen3-vl-4b" key of HF_MODEL_CONFIGS. The dict has top-level multimodal
# token ids plus nested `text_config` and `vision_config` sections.
qwen3_vl_4b_dict = {
    "architectures": ["Qwen3VLForConditionalGeneration"],
    "image_token_id": 151655,
    "model_type": "qwen3_vl",
    # Language-model section (model_type "qwen3_vl_text").
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "bos_token_id": 151643,
        "dtype": "bfloat16",
        "eos_token_id": 151645,
        "head_dim": 128,
        "hidden_act": "silu",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 9728,
        "max_position_embeddings": 262144,
        "model_type": "qwen3_vl_text",
        "num_attention_heads": 32,
        "num_hidden_layers": 36,
        "num_key_value_heads": 8,
        "pad_token_id": None,
        "rms_norm_eps": 1e-06,
        "rope_parameters": {
            "mrope_interleaved": True,
            # The three sections sum to 64, i.e. head_dim / 2.
            "mrope_section": [24, 20, 20],
            "rope_theta": 5000000,
            "rope_type": "default",
        },
        "tie_word_embeddings": True,
        "use_cache": True,
        "vocab_size": 151936,
    },
    # Repeated at the top level with the same value as in `text_config`.
    "tie_word_embeddings": True,
    "transformers_version": "5.8.0",
    "video_token_id": 151656,
    # Vision section (model_type "qwen3_vl_vision").
    "vision_config": {
        "deepstack_visual_indexes": [5, 11, 17],
        "depth": 24,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1024,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4096,
        "model_type": "qwen3_vl_vision",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        # Equals text_config["hidden_size"] (2560).
        "out_hidden_size": 2560,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 151653,
    "vision_start_token_id": 151652,
}
# NOTE(review): built with PTConfig rather than a dedicated transformers config
# class (the olmo3 entry above uses transformers.Olmo3Config) -- confirm that
# the nested text/vision sections are handled as intended by downstream code.
qwen3_vl_4b_config = PTConfig(**qwen3_vl_4b_dict)


# {maxtext model name: hf model config}
HF_MODEL_CONFIGS = {
"gemma2-2b": gemma2_2b_config,
Expand All @@ -1612,6 +1669,7 @@ def __init__(self, **kwargs):
"qwen3-14b": qwen3_14b_config,
"qwen3-14b-base": qwen3_14b_config,
"qwen3-32b": qwen3_32b_config,
"qwen3-vl-4b": qwen3_vl_4b_config,
"llama3.1-8b": llama31_8b_config,
"llama3.1-8b-Instruct": llama31_8b_config,
"llama3.1-70b": llama31_70b_config,
Expand Down
57 changes: 57 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,62 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
return shapes


def QWEN3_VL_HF_WEIGHTS_TO_SHAPE(config):
  """Returns mapping between HuggingFace Qwen3-VL weights path and weights shape.

  Text weights come from QWEN_HF_WEIGHTS_TO_SHAPE, with every key that starts
  with `model.` moved under `model.language_model.`; all other text keys are
  kept unchanged. Vision weights are then added under `model.visual.`.

  Args:
    config: Mapping with a "text_config" entry (forwarded to
      QWEN_HF_WEIGHTS_TO_SHAPE) and a "vision_config" entry providing "depth",
      "hidden_size", "intermediate_size", "out_hidden_size" and
      "num_position_embeddings". Optional vision keys and their defaults:
      "in_channels" (3), "temporal_patch_size" (2), "patch_size" (16),
      "spatial_merge_size" (2) and "deepstack_visual_indexes" ([5, 11, 17]).

  Returns:
    Dict mapping each HF weight path to its shape.

  Raises:
    KeyError: If a required config entry is missing.
  """
  text_prefix = "model."
  vl_prefix = "model.language_model."

  vl_shapes = {}
  for name, shape in QWEN_HF_WEIGHTS_TO_SHAPE(config["text_config"]).items():
    if name.startswith(text_prefix):
      # Rewrite only the leading prefix, never a later "model." occurrence.
      name = vl_prefix + name[len(text_prefix) :]
    vl_shapes[name] = shape

  vision_config = config["vision_config"]
  v_depth = vision_config["depth"]
  v_hidden_size = vision_config["hidden_size"]
  v_intermediate_size = vision_config["intermediate_size"]
  v_out_hidden_size = vision_config["out_hidden_size"]
  # Defaults reproduce the values that used to be hard-coded here.
  v_in_channels = vision_config.get("in_channels", 3)
  v_temporal_patch_size = vision_config.get("temporal_patch_size", 2)
  v_patch_size = vision_config.get("patch_size", 16)
  v_spatial_merge_size = vision_config.get("spatial_merge_size", 2)
  # The HF patch merger works on spatial_merge_size**2 concatenated patch
  # features, so its width is hidden_size * spatial_merge_size**2 rather than
  # the MLP intermediate_size. The two only coincide for the 4B config
  # (1024 * 4 == 4096).
  v_merged_size = v_hidden_size * v_spatial_merge_size**2

  vl_shapes["model.visual.patch_embed.proj.weight"] = [
      v_hidden_size,
      v_in_channels,
      v_temporal_patch_size,
      v_patch_size,
      v_patch_size,
  ]
  vl_shapes["model.visual.patch_embed.proj.bias"] = [v_hidden_size]
  vl_shapes["model.visual.pos_embed.weight"] = [vision_config["num_position_embeddings"], v_hidden_size]

  for i in range(v_depth):
    prefix = f"model.visual.blocks.{i}"
    vl_shapes[f"{prefix}.norm1.weight"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm1.bias"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm2.weight"] = [v_hidden_size]
    vl_shapes[f"{prefix}.norm2.bias"] = [v_hidden_size]
    # q, k and v are fused along the output dimension.
    vl_shapes[f"{prefix}.attn.qkv.weight"] = [3 * v_hidden_size, v_hidden_size]
    vl_shapes[f"{prefix}.attn.qkv.bias"] = [3 * v_hidden_size]
    vl_shapes[f"{prefix}.attn.proj.weight"] = [v_hidden_size, v_hidden_size]
    vl_shapes[f"{prefix}.attn.proj.bias"] = [v_hidden_size]
    vl_shapes[f"{prefix}.mlp.linear_fc1.weight"] = [v_intermediate_size, v_hidden_size]
    vl_shapes[f"{prefix}.mlp.linear_fc1.bias"] = [v_intermediate_size]
    vl_shapes[f"{prefix}.mlp.linear_fc2.weight"] = [v_hidden_size, v_intermediate_size]
    vl_shapes[f"{prefix}.mlp.linear_fc2.bias"] = [v_hidden_size]

  # One deepstack merger per configured index; only the count matters here.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"model.visual.deepstack_merger_list.{merger_idx}"
    # Deepstack mergers normalize the merged features.
    vl_shapes[f"{prefix}.norm.weight"] = [v_merged_size]
    vl_shapes[f"{prefix}.norm.bias"] = [v_merged_size]
    vl_shapes[f"{prefix}.linear_fc1.weight"] = [v_merged_size, v_merged_size]
    vl_shapes[f"{prefix}.linear_fc1.bias"] = [v_merged_size]
    vl_shapes[f"{prefix}.linear_fc2.weight"] = [v_out_hidden_size, v_merged_size]
    vl_shapes[f"{prefix}.linear_fc2.bias"] = [v_out_hidden_size]

  # The final merger normalizes per-patch features (width hidden_size).
  vl_shapes["model.visual.merger.norm.weight"] = [v_hidden_size]
  vl_shapes["model.visual.merger.norm.bias"] = [v_hidden_size]
  vl_shapes["model.visual.merger.linear_fc1.weight"] = [v_merged_size, v_merged_size]
  vl_shapes["model.visual.merger.linear_fc1.bias"] = [v_merged_size]
  vl_shapes["model.visual.merger.linear_fc2.weight"] = [v_out_hidden_size, v_merged_size]
  vl_shapes["model.visual.merger.linear_fc2.bias"] = [v_out_hidden_size]

  return vl_shapes


# {maxtext model name: {hf weight name: hf shape}}
HF_SHAPE = {
"gemma2-2b": GEMMA2_HF_WEIGHTS_TO_SHAPE,
Expand All @@ -1126,6 +1182,7 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
"qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-vl-4b": QWEN3_VL_HF_WEIGHTS_TO_SHAPE,
"llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"llama3.1-8b-Instruct": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"llama3.1-70b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
Expand Down
196 changes: 196 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3663,6 +3663,200 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
return hooks


def QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen3-VL weight paths.

  The text entries are produced by QWEN_MAXTEXT_TO_HF_PARAM_MAPPING and every
  HF path that starts with `model.` is moved under `model.language_model.`.
  Vision entries are emitted per layer regardless of `scan_layers`.

  Args:
    config: Mapping with `config["text_config"]["num_hidden_layers"]` and a
      "vision_config" entry providing "depth" and, optionally,
      "deepstack_visual_indexes" (default [5, 11, 17]).
    maxtext_config: Forwarded unchanged to QWEN_MAXTEXT_TO_HF_PARAM_MAPPING.
    scan_layers: Forwarded unchanged to QWEN_MAXTEXT_TO_HF_PARAM_MAPPING.

  Returns:
    Dict whose keys are MaxText parameter paths (a 3-tuple of query/key/value
    paths for the fused vision qkv weights) and whose values are HF paths.

  Raises:
    KeyError: If a required config entry is missing.
  """
  text_prefix = "model."
  vl_prefix = "model.language_model."

  def replace_prefix(val):
    """Moves a leading `model.` under `model.language_model.`, recursing into lists."""
    if isinstance(val, list):
      return [replace_prefix(v) for v in val]
    # Anchor on the prefix so a later "model." inside the path is never rewritten.
    if isinstance(val, str) and val.startswith(text_prefix):
      return vl_prefix + val[len(text_prefix) :]
    return val

  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
      config={"num_hidden_layers": config["text_config"]["num_hidden_layers"]},
      maxtext_config=maxtext_config,
      scan_layers=scan_layers,
  )
  mapping = {key: replace_prefix(value) for key, value in text_mapping.items()}

  vision_config = config["vision_config"]
  n_vision_layers = vision_config["depth"]

  encoder = "params-vision_encoder-Qwen3VLVisionEncoder_0"
  projector_merger = "params-vision_encoder-Qwen3VLVisionProjector_0-merger"
  # (MaxText leaf, HF leaf) pairs shared by the deepstack mergers and the final merger.
  merger_leaves = (
      ("ln_q-scale", "norm.weight"),
      ("ln_q-bias", "norm.bias"),
      ("mlp_0-kernel", "linear_fc1.weight"),
      ("mlp_0-bias", "linear_fc1.bias"),
      ("mlp_2-kernel", "linear_fc2.weight"),
      ("mlp_2-bias", "linear_fc2.bias"),
  )

  mapping[f"{encoder}-patch_embed-proj-kernel"] = "model.visual.patch_embed.proj.weight"
  mapping[f"{encoder}-patch_embed-proj-bias"] = "model.visual.patch_embed.proj.bias"
  mapping[f"{encoder}-pos_embed_interpolate-pos_embed"] = "model.visual.pos_embed.weight"

  for i in range(n_vision_layers):
    prefix = f"{encoder}-blocks_{i}"
    hf_prefix = f"model.visual.blocks.{i}"

    mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
    mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
    mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
    mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"

    # HF stores q, k and v fused; a tuple key groups the three MaxText params.
    mapping[
        (
            f"{prefix}-attn-attn-query-kernel",
            f"{prefix}-attn-attn-key-kernel",
            f"{prefix}-attn-attn-value-kernel",
        )
    ] = f"{hf_prefix}.attn.qkv.weight"
    mapping[
        (
            f"{prefix}-attn-attn-query-bias",
            f"{prefix}-attn-attn-key-bias",
            f"{prefix}-attn-attn-value-bias",
        )
    ] = f"{hf_prefix}.attn.qkv.bias"
    mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
    mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"

    mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
    mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
    mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
    mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"

  # One deepstack merger per configured index; only the count matters here.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"{encoder}-merger_{merger_idx}"
    hf_prefix = f"model.visual.deepstack_merger_list.{merger_idx}"
    for mt_leaf, hf_leaf in merger_leaves:
      mapping[f"{prefix}-{mt_leaf}"] = f"{hf_prefix}.{hf_leaf}"

  for mt_leaf, hf_leaf in merger_leaves:
    mapping[f"{projector_merger}-{mt_leaf}"] = f"model.visual.merger.{hf_leaf}"

  return mapping


def QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Qwen3-VL.

  Text hooks come unchanged from QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN; vision hooks
  are added for the patch embedding, the per-block attention and MLP kernels,
  and the merger MLP kernels.

  Args:
    config: Mapping with `config["text_config"]["num_hidden_layers"]` and a
      "vision_config" entry providing "depth", "hidden_size" and, optionally,
      "deepstack_visual_indexes" (default [5, 11, 17]).
    maxtext_config: Forwarded unchanged to QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN.
    scan_layers: Forwarded unchanged to QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN.
    saving_to_hf: Direction switch captured by every vision hook; True selects
      the MaxText -> HF transform, False the HF -> MaxText one.

  Returns:
    Dict from MaxText parameter path (or 3-tuple of query/key/value paths) to
    a callable taking `(input_tensor, target_shape)`.
  """
  hooks = dict(
      QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config={"num_hidden_layers": config["text_config"]["num_hidden_layers"]},
          maxtext_config=maxtext_config,
          scan_layers=scan_layers,
          saving_to_hf=saving_to_hf,
      )
  )

  vision_config = config["vision_config"]
  num_blocks = vision_config["depth"]
  hidden_size = vision_config["hidden_size"]

  # Row ranges of the q, k and v parts inside a fused HF qkv tensor.
  qkv_rows = (
      slice(None, hidden_size),
      slice(hidden_size, 2 * hidden_size),
      slice(2 * hidden_size, None),
  )

  def reshape_kernel_vision(input_tensor, target_shape):
    """Reshape kernel for vision layers."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    # Reshape to the reversed target shape first so the transpose lands on target_shape.
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def reshape_conv3d_patch_embed(input_tensor, target_shape):
    """Reshape 3D conv patch embedding weight."""
    # The two permutations are inverses of each other; target_shape is unused.
    axes = (4, 3, 0, 1, 2) if saving_to_hf else (2, 3, 4, 1, 0)
    return input_tensor.transpose(*axes)

  def process_qkv_vision(input_tensor, target_shape=None):
    """Handles composite_mt_key: maxtext (query, key, value) <-> hf (qkv)."""
    if saving_to_hf:
      query, key, value = input_tensor
      fused_parts = [part.reshape(hidden_size, hidden_size).T for part in (query, key, value)]
      return np.concatenate(fused_parts, axis=0)
    split_parts = [
        input_tensor[rows, :].T.reshape(target_shape[idx]) for idx, rows in enumerate(qkv_rows)
    ]
    return np.stack(split_parts, axis=-1)

  def process_qkv_bias_vision(input_tensor, target_shape=None):
    """Handles composite_mt_key: maxtext (query_bias, key_bias, value_bias) <-> hf (qkv_bias)."""
    if saving_to_hf:
      query_bias, key_bias, value_bias = input_tensor
      fused_parts = [part.reshape(hidden_size) for part in (query_bias, key_bias, value_bias)]
      return np.concatenate(fused_parts, axis=0)
    split_parts = [input_tensor[rows].reshape(target_shape[idx]) for idx, rows in enumerate(qkv_rows)]
    return np.stack(split_parts, axis=-1)

  def reshape_vision_attn_out(input_tensor, target_shape):
    """Reshape vision attention output projection."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(hidden_size, hidden_size).T

  encoder = "params-vision_encoder-Qwen3VLVisionEncoder_0"
  hooks[f"{encoder}-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed

  for layer_idx in range(num_blocks):
    block = f"{encoder}-blocks_{layer_idx}"

    # Fused qkv tensors are keyed by the tuple of their three MaxText params.
    for leaf, hook in (("kernel", process_qkv_vision), ("bias", process_qkv_bias_vision)):
      composite_key = tuple(f"{block}-attn-attn-{part}-{leaf}" for part in ("query", "key", "value"))
      hooks[composite_key] = hook

    hooks[f"{block}-attn-attn-out-kernel"] = reshape_vision_attn_out
    hooks[f"{block}-mlp-kernel"] = reshape_kernel_vision
    hooks[f"{block}-mlp_out-kernel"] = reshape_kernel_vision

  # Deepstack mergers first, then the final projector merger; only the kernels need a hook.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [5, 11, 17])
  merger_prefixes = [f"{encoder}-merger_{merger_idx}" for merger_idx, _ in enumerate(deepstack_indexes)]
  merger_prefixes.append("params-vision_encoder-Qwen3VLVisionProjector_0-merger")
  for merger in merger_prefixes:
    for leaf in ("mlp_0-kernel", "mlp_2-kernel"):
      hooks[f"{merger}-{leaf}"] = reshape_kernel_vision

  return hooks


# {maxtext model name: {maxtext weight name: hf weight name}}
PARAM_MAPPING = {
"gemma2-2b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
Expand All @@ -3689,6 +3883,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
Expand Down Expand Up @@ -3739,6 +3934,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def save_config_file(
):
"""Saves the model configuration file(config.json)."""
if jax.process_index() == 0:
config.architectures = [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]]
if config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
config.architectures = [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]]
if output_dir_final.startswith("hf://"):
max_logging.log(f" Serializing {file_name} to memory for Hugging Face Hub upload...")
json_string = config.to_json_string()
Expand Down
Loading
Loading