Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"vulkan_8w",
"tosa_8a8w",
"ethosu_8a8w",
"ethosu_16a8w",
"vgf_8a8w",
"vgf_16a8w",
Comment on lines 231 to 236
],
Expand Down Expand Up @@ -845,17 +846,29 @@ def get_quantizer_and_quant_params(llm_config):
llm_config.quantization.pt2e_quantize.value
)
quantizers.append(coreml_quantizer)
arm_quantize_scope = llm_config.quantization.quantize_scope.value
if (
arm_quantize_scope == "full"
and llm_config.backend.vgf.enabled
and llm_config.backend.vgf.quantize_scope.value != "full"
):
arm_quantize_scope = llm_config.backend.vgf.quantize_scope.value

if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize:
tosa_quantizer = get_tosa_quantizer(
llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value
llm_config.backend.tosa.version,
llm_config.quantization.pt2e_quantize.value,
arm_quantize_scope,
)
quantizers.append(tosa_quantizer)
if llm_config.backend.ethosu.enabled and llm_config.quantization.pt2e_quantize:
ethosu_quantizer = get_ethosu_quantizer(
llm_config.backend.ethosu.target,
llm_config.backend.ethosu.system_config,
llm_config.backend.ethosu.memory_mode,
llm_config.backend.ethosu.extra_flags,
llm_config.quantization.pt2e_quantize.value,
arm_quantize_scope,
)
quantizers.append(ethosu_quantizer)
if llm_config.backend.vgf.enabled and llm_config.quantization.pt2e_quantize:
Expand Down Expand Up @@ -1054,6 +1067,7 @@ def _to_edge_and_lower_llama_arm(
llm_config.backend.ethosu.target,
llm_config.backend.ethosu.system_config,
llm_config.backend.ethosu.memory_mode,
llm_config.backend.ethosu.extra_flags,
)
)
modelname = f"ethosu_{modelname}"
Expand Down
15 changes: 14 additions & 1 deletion extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum):
vulkan_8w = "vulkan_8w"
tosa_8a8w = "tosa_8a8w"
ethosu_8a8w = "ethosu_8a8w"
ethosu_16a8w = "ethosu_16a8w"
vgf_8a8w = "vgf_8a8w"
vgf_16a8w = "vgf_16a8w"

Expand All @@ -386,6 +387,11 @@ class SpinQuant(str, Enum):
native = "native"


class QuantizeScope(str, Enum):
    """Which part of the model Arm PT2E quantization is applied to.

    Members are also ``str`` instances, so they compare equal to their plain
    string values (for example ``QuantizeScope.full == "full"``).
    """

    # Quantize the full supported graph.
    full = "full"
    # Limit quantization to torch.nn.Linear modules.
    linear = "linear"


@dataclass
class QuantizationConfig:
"""
Expand All @@ -403,6 +409,9 @@ class QuantizationConfig:
use_spin_quant: Which spin quant mode to use. If unspecified, don't use
spin quant.
use_qat: Whether the checkpoint was trained with quantization-aware training (QAT).
quantize_scope: Scope for Arm PT2E quantization. "full" quantizes the
full supported graph, while "linear" limits quantization to
torch.nn.Linear modules.
calibration_tasks: Tasks for GPTQ calibration from lm_eval.
calibration_limit: Number of samples used for calibration from lm_eval.
calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
Expand All @@ -427,6 +436,7 @@ class QuantizationConfig:
group_size: Optional[int] = None
use_spin_quant: Optional[SpinQuant] = None
use_qat: bool = False
quantize_scope: QuantizeScope = QuantizeScope.full
calibration_tasks: Optional[List[str]] = None
calibration_limit: Optional[int] = None
calibration_seq_length: Optional[int] = None
Expand Down Expand Up @@ -587,6 +597,7 @@ class EthosUConfig:
target: str = "ethos-u85-128" # Default target, can be overridden.
memory_mode: str = "default"
system_config: str = "default"
extra_flags: List[str] = field(default_factory=list)


class VgfQuantizeScope(str, Enum):
Expand Down Expand Up @@ -832,7 +843,9 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope(
args.vgf_quantize_scope
)

llm_config.quantization.quantize_scope = QuantizeScope(
args.vgf_quantize_scope
)
# TorchAoKernels
if any(
hasattr(args, a)
Expand Down
2 changes: 2 additions & 0 deletions extension/llm/export/partitioner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def get_ethosu_partitioner(
target: str,
system_config: Optional[str] = None,
memory_mode: Optional[str] = None,
extra_flags: Optional[List[str]] = None,
):
from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
from executorch.backends.arm.ethosu.partitioner import EthosUPartitioner
Expand All @@ -260,6 +261,7 @@ def get_ethosu_partitioner(
target,
system_config=None if system_config == "default" else system_config,
memory_mode=None if memory_mode == "default" else memory_mode,
extra_flags=extra_flags,
)

return EthosUPartitioner(compile_spec)
Expand Down
65 changes: 53 additions & 12 deletions extension/llm/export/quantizer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def get_vulkan_quantizer(pt2e_quantize: str):
return quantizer


def get_tosa_quantizer(version: str, pt2e_quantize: str):
def get_tosa_quantizer(version: str, pt2e_quantize: str, quantize_scope: str):
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
Expand All @@ -335,34 +335,76 @@ def get_tosa_quantizer(version: str, pt2e_quantize: str):
quantizer = TOSAQuantizer(compile_spec)

if pt2e_quantize == "tosa_8a8w":
quantizer.set_global(get_symmetric_quantization_config())
quantization_config = get_symmetric_quantization_config()
else:
raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

_apply_arm_quantize_scope(
quantizer,
quantization_config=quantization_config,
quantize_scope=quantize_scope,
backend_name="TOSA",
)
return quantizer


def get_ethosu_quantizer(
    target: str,
    system_config: str,
    memory_mode: str,
    extra_flags: Optional[List[str]],
    pt2e_quantize: str,
    quantize_scope: str = "full",
):
    """Build an ``EthosUQuantizer`` for the requested PT2E scheme and scope.

    Args:
        target: Ethos-U target string forwarded to ``EthosUCompileSpec``.
        system_config: System configuration name. The placeholder value
            ``"default"`` is forwarded as ``None``, mirroring
            ``get_ethosu_partitioner``.
        memory_mode: Memory mode name. ``"default"`` is likewise forwarded
            as ``None``.
        extra_flags: Optional extra flags forwarded to ``EthosUCompileSpec``.
        pt2e_quantize: Quantization scheme; ``"ethosu_8a8w"`` or
            ``"ethosu_16a8w"``.
        quantize_scope: ``"full"`` to set the config globally, or ``"linear"``
            to restrict it to ``torch.nn.Linear`` modules. Defaults to
            ``"full"``.

    Returns:
        The configured ``EthosUQuantizer``.

    Raises:
        ValueError: If ``pt2e_quantize`` or ``quantize_scope`` is unsupported.
    """
    from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
    from executorch.backends.arm.quantizer.arm_quantizer import (
        EthosUQuantizer,
        get_symmetric_a16w8_quantization_config,
        get_symmetric_quantization_config,
    )

    compile_spec = EthosUCompileSpec(
        target,
        # Translate the "default" placeholder the same way
        # get_ethosu_partitioner does, so both code paths describe the same
        # compile spec.
        system_config=None if system_config == "default" else system_config,
        memory_mode=None if memory_mode == "default" else memory_mode,
        # Hand over a copy: the same config list is also passed to the
        # partitioner, so this guards against downstream in-place edits.
        extra_flags=None if extra_flags is None else list(extra_flags),
    )

    quantizer = EthosUQuantizer(compile_spec)

    if pt2e_quantize == "ethosu_8a8w":
        quantization_config = get_symmetric_quantization_config()
    elif pt2e_quantize == "ethosu_16a8w":
        quantization_config = get_symmetric_a16w8_quantization_config()
    else:
        raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

    _apply_arm_quantize_scope(
        quantizer,
        quantization_config=quantization_config,
        quantize_scope=quantize_scope,
        backend_name="Ethos-U",
    )
    return quantizer


def _apply_arm_quantize_scope(
quantizer,
*,
quantization_config,
quantize_scope: str,
backend_name: str,
):
if quantize_scope == "full":
quantizer.set_global(quantization_config)
elif quantize_scope == "linear":
quantizer.set_module_type(torch.nn.Linear, quantization_config)
else:
raise ValueError(
f"Unsupported {backend_name} quantization scope {quantize_scope}"
)


def get_vgf_quantizer(
compile_spec: Optional[str],
compiler_flags: Optional[List[str]],
Expand Down Expand Up @@ -392,11 +434,10 @@ def get_vgf_quantizer(
else:
raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

if quantize_scope == "full":
quantizer.set_global(quantization_config)
elif quantize_scope == "linear":
quantizer.set_module_type(torch.nn.Linear, quantization_config)
else:
raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}")

_apply_arm_quantize_scope(
quantizer,
quantization_config=quantization_config,
quantize_scope=quantize_scope,
backend_name="VGF",
)
return quantizer
Loading