diff --git a/fms_mo/run_quant.py b/fms_mo/run_quant.py index a497239..4521dca 100644 --- a/fms_mo/run_quant.py +++ b/fms_mo/run_quant.py @@ -214,12 +214,12 @@ def run_fp8(model_args, data_args, opt_args, fp8_args): # Third Party from llmcompressor import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier - from llmcompressor.transformers import SparseAutoModelForCausalLM + from transformers import AutoModelForCausalLM logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8") if model_args.task_type == "lm": - model = SparseAutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, torch_dtype=model_args.torch_dtype, )