diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index efb1bdd38b4..a2ced502ac5 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -15,13 +15,29 @@ ) -def convert_unsafe(neutron_converter, tflite_model, cctx, queue): +def _build_compilation_context(compilation_opts): + """Build a CompilationContext from a plain dict of options.""" + cctx = neutron_converter.CompilationContext() + cctx.targetOpts = neutron_converter.getNeutronTarget(compilation_opts["target"]) + cctx.compilationOpts.minNumOpsPerGraph = compilation_opts["minNumOpsPerGraph"] + cctx.compilationOpts.excludeGraphPasses = compilation_opts["excludeGraphPasses"] + cctx.compilationOpts.fetchConstantsToSRAM = compilation_opts["fetchConstantsToSRAM"] + cctx.compilationOpts.dumpKernelSelectionCode = compilation_opts[ + "dumpKernelSelectionCode" + ] + if hasattr(cctx.compilationOpts, "useNewFlowNeutronC"): + cctx.compilationOpts.useNewFlowNeutronC = compilation_opts["useNewFlowNeutronC"] + return cctx + + +def convert_unsafe(tflite_model, compilation_opts, queue): """ - Run neutron_converter on given tflite_model with compilation context cctx. + Run neutron_converter on given tflite_model with the provided compilation options. This routine is supposed to run in a separate process. If properly finished, the output queue contains the converted model, otherwise the neutron_converter exits and the output queue is empty. """ + cctx = _build_compilation_context(compilation_opts) model_converted = neutron_converter.convertModel(list(tflite_model), cctx) queue.put(model_converted) @@ -84,16 +100,14 @@ def convert( # Neutron converter crashes if we provide invalid target -> verify. self.verify_target(target) - cctx = neutron_converter.CompilationContext() - cctx.targetOpts = neutron_converter.getNeutronTarget(target) - cctx.compilationOpts.minNumOpsPerGraph = 1 - cctx.compilationOpts.excludeGraphPasses = ( - "HoistSliceAboveTranspose,MergeTranspose" - ) - cctx.compilationOpts.fetchConstantsToSRAM = fetch_constants_to_sram - cctx.compilationOpts.dumpKernelSelectionCode = self.dump_kernel_selection_code - if hasattr(cctx.compilationOpts, "useNewFlowNeutronC"): - cctx.compilationOpts.useNewFlowNeutronC = use_new_flow_neutron_c + compilation_opts = { + "target": target, + "minNumOpsPerGraph": 1, + "excludeGraphPasses": "HoistSliceAboveTranspose,MergeTranspose", + "fetchConstantsToSRAM": fetch_constants_to_sram, + "dumpKernelSelectionCode": self.dump_kernel_selection_code, + "useNewFlowNeutronC": use_new_flow_neutron_c, + } # Try to use multiprocessing for isolation, but fall back to direct execution # if the environment doesn't support it (e.g., in sandcastle/build environments) @@ -104,7 +118,7 @@ def convert( process = multiprocessing.Process( target=convert_unsafe, - args=(neutron_converter, tflite_model, cctx, queue), + args=(tflite_model, compilation_opts, queue), ) process.start() process.join() # waits until the subprocess is complete @@ -116,12 +130,13 @@ def convert( model_converted = queue.get() process.close() - except (EOFError, OSError) as e: + except (EOFError, OSError, TypeError) as e: # Multiprocessing failed (likely due to environment restrictions) # Fall back to direct execution logging.warning( f"Multiprocessing not available ({e}), running neutron converter directly" ) + cctx = _build_compilation_context(compilation_opts) model_converted = neutron_converter.convertModel(list(tflite_model), cctx) if self.dump_kernel_selection_code: self._rename_partition_kernel_selection_file(delegation_tag) diff --git a/backends/nxp/tests/BUCK b/backends/nxp/tests/BUCK index c16d6267425..2e793e81d96 100644 --- a/backends/nxp/tests/BUCK +++ b/backends/nxp/tests/BUCK @@ -112,6 +112,20 @@ fbcode_target(_kind = python_pytest, ], ) +fbcode_target(_kind = python_pytest, + name = "test_neutron_converter_manager", + srcs = [ + "generic_tests/test_neutron_converter_manager.py", + ], + deps = [ + "//executorch/backends/nxp:neutron_sdk", + "//executorch/exir:lib", + ":executorch_pipeline", + ":models", + "fbsource//third-party/pypi/pytest-mock:pytest-mock", # @manual + ], +) + fbcode_target(_kind = python_pytest, name = "test_integration", srcs = [ diff --git a/backends/nxp/tests/generic_tests/test_neutron_converter_manager.py b/backends/nxp/tests/generic_tests/test_neutron_converter_manager.py index c00cc507bbc..d26a1444b3b 100644 --- a/backends/nxp/tests/generic_tests/test_neutron_converter_manager.py +++ b/backends/nxp/tests/generic_tests/test_neutron_converter_manager.py @@ -4,9 +4,9 @@ # LICENSE file in the root directory of this source tree. import multiprocessing +import pickle import torch -from eiq_neutron_sdk.neutron_converter.neutron_converter import CompilationContext from executorch import exir from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, @@ -69,7 +69,28 @@ def test_neutron_converter_with_experimental_mlir_flow(mocker): model, input_shape, use_new_flow_neutron_c=True ).exported_program() - compilation_context = process_spy.call_args.kwargs["args"][2] - assert isinstance(compilation_context, CompilationContext) - if hasattr(compilation_context.compilationOpts, "useNewFlowNeutronC"): - assert compilation_context.compilationOpts.useNewFlowNeutronC + compilation_opts = process_spy.call_args.kwargs["args"][1] + assert isinstance(compilation_opts, dict) + assert compilation_opts["useNewFlowNeutronC"] is True + + +def test_convert_unsafe_args_are_picklable(mocker): + """Verify that all args passed to multiprocessing.Process are picklable. + + The subprocess uses forkserver/spawn in some environments, which requires + all Process args to be serializable via pickle. + """ + model = LinearModule(True) + input_shape = (1, 1, 32, 32) + + process_spy = mocker.spy(multiprocessing, "Process") + to_quantized_edge_program(model, input_shape).exported_program() + + args = process_spy.call_args.kwargs["args"] + for i, arg in enumerate(args): + try: + pickle.dumps(arg) + except (pickle.PicklingError, TypeError) as e: + raise AssertionError( + f"Process arg at index {i} ({type(arg).__name__}) is not picklable: {e}" + )