diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index d8241469b65..4bb863e54cb 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -231,6 +231,7 @@ def build_args_parser() -> argparse.ArgumentParser: "vulkan_8w", "tosa_8a8w", "ethosu_8a8w", + "ethosu_16a8w", "vgf_8a8w", "vgf_16a8w", ], @@ -845,9 +846,19 @@ def get_quantizer_and_quant_params(llm_config): llm_config.quantization.pt2e_quantize.value ) quantizers.append(coreml_quantizer) + arm_quantize_scope = llm_config.quantization.quantize_scope.value + if ( + arm_quantize_scope == "full" + and llm_config.backend.vgf.enabled + and llm_config.backend.vgf.quantize_scope.value != "full" + ): + arm_quantize_scope = llm_config.backend.vgf.quantize_scope.value + if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize: tosa_quantizer = get_tosa_quantizer( - llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value + llm_config.backend.tosa.version, + llm_config.quantization.pt2e_quantize.value, + arm_quantize_scope, ) quantizers.append(tosa_quantizer) if llm_config.backend.ethosu.enabled and llm_config.quantization.pt2e_quantize: @@ -855,7 +866,9 @@ def get_quantizer_and_quant_params(llm_config): llm_config.backend.ethosu.target, llm_config.backend.ethosu.system_config, llm_config.backend.ethosu.memory_mode, + llm_config.backend.ethosu.extra_flags, llm_config.quantization.pt2e_quantize.value, + arm_quantize_scope, ) quantizers.append(ethosu_quantizer) if llm_config.backend.vgf.enabled and llm_config.quantization.pt2e_quantize: @@ -1054,6 +1067,7 @@ def _to_edge_and_lower_llama_arm( llm_config.backend.ethosu.target, llm_config.backend.ethosu.system_config, llm_config.backend.ethosu.memory_mode, + llm_config.backend.ethosu.extra_flags, ) ) modelname = f"ethosu_{modelname}" diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 2f3d10f54f8..2b01fdca5a9 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum): vulkan_8w = "vulkan_8w" tosa_8a8w = "tosa_8a8w" ethosu_8a8w = "ethosu_8a8w" + ethosu_16a8w = "ethosu_16a8w" vgf_8a8w = "vgf_8a8w" vgf_16a8w = "vgf_16a8w" @@ -386,6 +387,11 @@ class SpinQuant(str, Enum): native = "native" +class QuantizeScope(str, Enum): + full = "full" + linear = "linear" + + @dataclass class QuantizationConfig: """ @@ -403,6 +409,9 @@ class QuantizationConfig: use_spin_quant: Which spin quant mode to use. If unspecified, don't use spin quant. use_qat: Whether the checkpoint is quantization-awarely trained. + quantize_scope: Scope for Arm PT2E quantization. "full" quantizes the + full supported graph, while "linear" limits quantization to + torch.nn.Linear modules. calibration_tasks: Tasks for GPTQ calibration from lm_eval. calibration_limit: Number of samples used for calibration from lm_eval. calibration_seq_length: Sequence length for GPTQ calibration from lm_eval. @@ -427,6 +436,7 @@ class QuantizationConfig: group_size: Optional[int] = None use_spin_quant: Optional[SpinQuant] = None use_qat: bool = False + quantize_scope: QuantizeScope = QuantizeScope.full calibration_tasks: Optional[List[str]] = None calibration_limit: Optional[int] = None calibration_seq_length: Optional[int] = None @@ -587,6 +597,7 @@ class EthosUConfig: target: str = "ethos-u85-128" # Default target, can be overridden. memory_mode: str = "default" system_config: str = "default" + extra_flags: List[str] = field(default_factory=list) class VgfQuantizeScope(str, Enum): @@ -832,7 +843,9 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.backend.vgf.quantize_scope = VgfQuantizeScope( args.vgf_quantize_scope ) - + llm_config.quantization.quantize_scope = QuantizeScope( + args.vgf_quantize_scope + ) # TorchAoKernels if any( hasattr(args, a) diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 0abb4b663fb..19c0b7fdcfb 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -252,6 +252,7 @@ def get_ethosu_partitioner( target: str, system_config: Optional[str] = None, memory_mode: Optional[str] = None, + extra_flags: Optional[List[str]] = None, ): from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec from executorch.backends.arm.ethosu.partitioner import EthosUPartitioner @@ -260,6 +261,7 @@ def get_ethosu_partitioner( target, system_config=None if system_config == "default" else system_config, memory_mode=None if memory_mode == "default" else memory_mode, + extra_flags=extra_flags, ) return EthosUPartitioner(compile_spec) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index cd70610ee11..8aa2277860d 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -323,7 +323,7 @@ def get_vulkan_quantizer(pt2e_quantize: str): return quantizer -def get_tosa_quantizer(version: str, pt2e_quantize: str): +def get_tosa_quantizer(version: str, pt2e_quantize: str, quantize_scope: str): from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, @@ -335,34 +335,76 @@ def get_tosa_quantizer(version: str, pt2e_quantize: str): quantizer = TOSAQuantizer(compile_spec) if pt2e_quantize == "tosa_8a8w": - quantizer.set_global(get_symmetric_quantization_config()) + quantization_config = get_symmetric_quantization_config() else: raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}") + _apply_arm_quantize_scope( + quantizer, + quantization_config=quantization_config, + quantize_scope=quantize_scope, + backend_name="TOSA", + ) return quantizer def get_ethosu_quantizer( - target: str, system_config: str, memory_mode: str, pt2e_quantize: str + target: str, + system_config: str, + memory_mode: str, + extra_flags: Optional[List[str]], + pt2e_quantize: str, + quantize_scope: str, ): from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec from executorch.backends.arm.quantizer.arm_quantizer import ( EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, ) - compile_spec = EthosUCompileSpec(target, system_config, memory_mode) + compile_spec = EthosUCompileSpec( + target, + system_config, + memory_mode, + extra_flags=extra_flags, + ) quantizer = EthosUQuantizer(compile_spec) if pt2e_quantize == "ethosu_8a8w": - quantizer.set_global(get_symmetric_quantization_config()) + quantization_config = get_symmetric_quantization_config() + elif pt2e_quantize == "ethosu_16a8w": + quantization_config = get_symmetric_a16w8_quantization_config() else: raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}") + _apply_arm_quantize_scope( + quantizer, + quantization_config=quantization_config, + quantize_scope=quantize_scope, + backend_name="Ethos-U", + ) return quantizer +def _apply_arm_quantize_scope( + quantizer, + *, + quantization_config, + quantize_scope: str, + backend_name: str, +): + if quantize_scope == "full": + quantizer.set_global(quantization_config) + elif quantize_scope == "linear": + quantizer.set_module_type(torch.nn.Linear, quantization_config) + else: + raise ValueError( + f"Unsupported {backend_name} quantization scope {quantize_scope}" + ) + + def get_vgf_quantizer( compile_spec: Optional[str], compiler_flags: Optional[List[str]], @@ -392,11 +434,10 @@ def get_vgf_quantizer( else: raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}") - if quantize_scope == "full": - quantizer.set_global(quantization_config) - elif quantize_scope == "linear": - quantizer.set_module_type(torch.nn.Linear, quantization_config) - else: - raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}") - + _apply_arm_quantize_scope( + quantizer, + quantization_config=quantization_config, + quantize_scope=quantize_scope, + backend_name="VGF", + ) return quantizer