-
Notifications
You must be signed in to change notification settings - Fork 657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
torch.fill_ can not apply after general function #1924
base: main
Are you sure you want to change the base?
Conversation
# update the mapping | ||
node_output = node.outputs[0] | ||
val = tensor_to_node_sequence_mapping[node_input] | ||
tensor_to_node_sequence_mapping[node_output] = val |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please share your reasoning on why this would be generally applicable for nodes in addition to to
?
Also, what is the reason to delete line del tensor_to_node_sequence_mapping[node_input]
? With that line, it makes tensor_to_node_sequence_mapping[node_output]
to be ordered as latest
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the reason to delete line del tensor_to_node_sequence_mapping[node_input]
with del tensor_to_node_sequence_mapping[node_input]
, if cast source is referenced more than once, ValueError will be raised.
class Net(torch.nn.Module):
def forward(self, x):
y = torch.empty(x.shape)
z = y.to(torch.int32)
w = y.to(torch.int32)
return z, w
stacktrace
Traceback (most recent call last):
File "/Users/ryosukefukatani/work/coremltools/onth11.py", line 21, in <module>
ct.TensorType("x", shape=(ct.RangeDim(), ct.RangeDim())),
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/_converters_entry.py", line 542, in convert
main_pipeline=pass_pipeline,
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 217, in _mil_convert
**kwargs
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
prog = frontend_converter(model, **kwargs)
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
return load(*args, **kwargs)
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 61, in load
specification_version,
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 335, in __init__
p(self.graph)
File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/torchir_passes.py", line 141, in generate_tensor_assignment_ops
raise ValueError("No matching select or slice.")
ValueError: No matching select or slice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please share your reasoning on why this would be generally applicable for nodes in addition to to?
We need general solution of #1917
ex.
class Net(torch.nn.Module):
def forward(self, x):
y = torch.empty(x.shape) + 1
y.fill_(0.0)
return y
class Net(torch.nn.Module):
def forward(self, x):
y = torch.empty(x.shape) - 1
y.fill_(0.0)
return y
class Net(torch.nn.Module):
def forward(self, x):
y = torch.empty(x.shape) .flatten()
y.fill_(0.0)
return y
and all other operations.
) | ||
y_cm = ct_model.predict({'x': x})['y'] | ||
|
||
assert((y_cm == np.zeros(shape)).all()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe remove the outmost parentheses to be assert all(y_cm == np.zeros(shape))
?
Resolves #1920