diff --git a/coremltools/converters/mil/debugging_utils.py b/coremltools/converters/mil/debugging_utils.py index 01792f0da..11ded2e93 100644 --- a/coremltools/converters/mil/debugging_utils.py +++ b/coremltools/converters/mil/debugging_utils.py @@ -77,7 +77,14 @@ def validate_inputs(func, input_vars): reachable_vars.add(op.outputs[0]) for op in func.operations: - if all([x in reachable_vars for x in op.inputs.values()]): + input_values = [] + for v in op.inputs.values(): + if isinstance(v, (list, tuple)): + input_values.extend(v) + else: + input_values.append(v) + + if all([x in reachable_vars for x in input_values]): reachable_vars.update(op.outputs) for out in func.outputs: @@ -170,6 +177,6 @@ def replace_inputs(func, input_vars): PASS_REGISTRY["common::dead_code_elimination"](prog) prog.skip_all_passes = True - submodel = ct.convert(prog, convert_to=backend, compute_units=model.compute_unit) + submodel = ct.convert(prog, convert_to=backend, compute_units=model.compute_unit, minimum_deployment_target=func.opset_version) return submodel diff --git a/coremltools/converters/mil/mil/tests/test_debug.py b/coremltools/converters/mil/mil/tests/test_debug.py index 0fa69dbf0..34f15ac4f 100644 --- a/coremltools/converters/mil/mil/tests/test_debug.py +++ b/coremltools/converters/mil/mil/tests/test_debug.py @@ -13,12 +13,12 @@ import coremltools as ct from coremltools.converters.mil import Builder as mb from coremltools.converters.mil.debugging_utils import extract_submodel -from coremltools.converters.mil.mil import get_new_symbol +from coremltools.converters.mil.mil import get_new_symbol, types from coremltools.converters.mil.mil.types.symbolic import is_symbolic from coremltools.converters.mil.testing_utils import get_op_types_in_program -def get_simple_program(): - @mb.program(input_specs=[mb.TensorSpec(shape=(1, 2, 3, 4)),]) +def get_simple_program(opset_vesrion=None): + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 2, 3, 4)),], opset_version=opset_vesrion) def prog(x): x = mb.add(x=x, y=1.2, name="add") x = mb.transpose(x=x, perm=[0, 2, 3, 1]) @@ -227,6 +227,32 @@ def prog(x, y): with pytest.raises(ValueError, match="output sin not reachable from inputs"): submodel = extract_submodel(model, outputs=["sin"], inputs=["mul"]) + def test_extract_submodel_tuple_input_ops(self): + """ + Input graph: + x -> relu ---> sin --- + | | + v v + cos -> concat -> tanh -> output_1 + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 2), dtype=types.fp16)], opset_version=ct.target.iOS16) + def prog(x): + relu = mb.relu(x=x, name="relu") + sin = mb.sin(x=relu, name="sin") + cos = mb.cos(x=relu, name="cos") + concat = mb.concat(values=[sin, cos], axis=1, name="concat") + tanh = mb.tanh(x=concat, name="tanh") + return tanh + + model = ct.convert(prog, convert_to="mlprogram") + submodel = extract_submodel(model, outputs=["tanh"], inputs=["relu"]) + + assert get_op_types_in_program(submodel._mil_program) == ["sin", "cos", "concat", "tanh"] + + outputs = list(submodel._mil_program.functions["main"].outputs) + assert len(outputs) == 1 + assert outputs[0].name == "tanh" + @pytest.mark.parametrize( "compute_unit", [ @@ -267,7 +293,7 @@ def test_extract_submodel_neuralnetwork(self, compute_unit): ) ) def test_extract_submodel_mlprogram(self, compute_unit, store_to_disk): - prog = get_simple_program() + prog = get_simple_program(ct.target.iOS16) model = ct.convert( prog, convert_to="mlprogram",