Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tuple inputs in extract_submodel #2267

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions coremltools/converters/mil/debugging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,14 @@ def validate_inputs(func, input_vars):
reachable_vars.add(op.outputs[0])

for op in func.operations:
if all([x in reachable_vars for x in op.inputs.values()]):
input_values = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much @smpanaro for putting this PR!
In order to get this PR merged,
could you also add an unittest for this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem. Just added one (and also fixed one that was failing because of the opset_version change).

for v in op.inputs.values():
if isinstance(v, (list, tuple)):
input_values.extend(v)
else:
input_values.append(v)

if all([x in reachable_vars for x in input_values]):
reachable_vars.update(op.outputs)

for out in func.outputs:
Expand Down Expand Up @@ -170,6 +177,6 @@ def replace_inputs(func, input_vars):
PASS_REGISTRY["common::dead_code_elimination"](prog)

prog.skip_all_passes = True
submodel = ct.convert(prog, convert_to=backend, compute_units=model.compute_unit)
submodel = ct.convert(prog, convert_to=backend, compute_units=model.compute_unit, minimum_deployment_target=func.opset_version)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related but fixes an error I get with my example conversion script:

RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: Error compiling model: "compiler error: Error reading protobuf spec. validator error: Description of multiarray feature 'x_cast_fp16' has FLOAT16 dataType, which is only valid in specification version >= 7. This model has version 6".

Happy to put this up as a separate PR if that's preferred.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a correct fix :)


return submodel