From 84334b976029c4c233e3ef286d2dd76227639d14 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 20:53:36 +0000 Subject: [PATCH 1/2] Initial plan From 06d338384136d58b8788f7fc1742c527d28b7f75 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:21:05 +0000 Subject: [PATCH 2/2] Refactor FalconMamba: replace class name checks with generic hasattr(config, 'use_mambapy') Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> --- onnx_diagnostic/tasks/text_generation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 25b4d29c..4fec3ac2 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -13,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]: """Reduces a model size.""" - # FalconMambaConfig: use_mambapy + # Mamba models (e.g. FalconMambaConfig) use use_mambapy instead of num_attention_heads if hasattr(config, "text_config"): # The model is probably of mixture of models used only for text. config = config.text_config @@ -25,7 +25,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]: "hidden_size", "vocab_size", ) - if config.__class__.__name__ == "FalconMambaConfig": + if hasattr(config, "use_mambapy"): check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8 kwargs = dict( num_hidden_layers=min(config.num_hidden_layers, nhl()), @@ -54,7 +54,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]: return kwargs -def _get_input_falcon_mamba( +def _get_input_mamba( model: torch.nn.Module, config: Optional[Any], dummy_max_token_id: int, @@ -157,8 +157,8 @@ def get_inputs( seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) - if config is not None and config.__class__.__name__ == "FalconMambaConfig": - res = _get_input_falcon_mamba( + if config is not None and hasattr(config, "use_mambapy"): + res = _get_input_mamba( model=model, config=config, dummy_max_token_id=dummy_max_token_id, @@ -343,7 +343,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: ("num_key_value_heads", "num_attention_heads", "use_mambapy"), "hidden_size", ) - if config.__class__.__name__ == "FalconMambaConfig": + if hasattr(config, "use_mambapy"): check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8 kwargs = dict( batch_size=2,