Skip to content

Commit

Permalink
allow passing input_names and output_names params to onnx convert (#1941
Browse files Browse the repository at this point in the history
)

* allow passing input_names and output_names params to onnx convert

* fix for None onnx_export_kwargs

* fix crash

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent 03c445c commit d7152a4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
12 changes: 10 additions & 2 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,15 @@ def export(

# This variable holds the output names of the model.
# If postprocessing is enabled, it will be set to the output names of the postprocessing module.
output_names: Optional[List[str]] = None
if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs:
output_names = onnx_export_kwargs.pop("output_names")
else:
output_names = None

if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs:
input_names = onnx_export_kwargs.pop("input_names")
else:
input_names = ["input"]

if isinstance(postprocessing, nn.Module):
# If a user-specified postprocessing module is provided, we will attach is to the model and not
Expand Down Expand Up @@ -452,7 +460,7 @@ def export(
model=complete_model,
model_input=onnx_input,
onnx_filename=output,
input_names=["input"],
input_names=input_names,
output_names=output_names,
onnx_opset=onnx_export_kwargs.get("opset_version", None),
do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,15 @@ def export(

# This variable holds the output names of the model.
# If postprocessing is enabled, it will be set to the output names of the postprocessing module.
output_names: Optional[List[str]] = None
if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs:
output_names = onnx_export_kwargs.pop("output_names")
else:
output_names = None

if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs:
input_names = onnx_export_kwargs.pop("input_names")
else:
input_names = ["input"]

if isinstance(postprocessing, nn.Module):
# If a user-specified postprocessing module is provided, we will attach is to the model and not
Expand Down Expand Up @@ -438,7 +446,7 @@ def export(
model=complete_model,
model_input=onnx_input,
onnx_filename=output,
input_names=["input"],
input_names=input_names,
output_names=output_names,
onnx_opset=onnx_export_kwargs.get("opset_version", None),
do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True),
Expand Down
12 changes: 10 additions & 2 deletions src/super_gradients/module_interfaces/exportable_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,15 @@ def export(

# This variable holds the output names of the model.
# If postprocessing is enabled, it will be set to the output names of the postprocessing module.
output_names: Optional[List[str]] = None
if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs:
output_names = onnx_export_kwargs.pop("output_names")
else:
output_names = None

if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs:
input_names = onnx_export_kwargs.pop("input_names")
else:
input_names = ["input"]

if isinstance(postprocessing, nn.Module):
# If a user-specified postprocessing module is provided, we will attach is to the model and not
Expand Down Expand Up @@ -412,7 +420,7 @@ def export(
model=complete_model,
model_input=onnx_input,
onnx_filename=output,
input_names=["input"],
input_names=input_names,
output_names=output_names,
onnx_opset=onnx_export_kwargs.get("opset_version", None),
do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True),
Expand Down

0 comments on commit d7152a4

Please sign in to comment.