Skip to content

Commit

Permalink
torch.log(): cast int arguments to float32 (#2017)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcuenca authored Nov 7, 2023
1 parent 4dc2ba8 commit b2f7190
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
5 changes: 4 additions & 1 deletion coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5124,7 +5124,10 @@ def reciprocal(context, node):
@register_torch_op
def log(context, node):
    """Lower ``torch.log`` to the MIL ``log`` op.

    MIL's ``log`` does not accept integer inputs, so integer tensors are
    first cast to fp32 (mirroring the commit intent: "cast int arguments
    to float32" — PyTorch promotes int tensors to float for ``log``).

    :param context: conversion context used to register the resulting op.
    :param node: the torch node being converted; exactly one input expected.
    """
    inputs = _get_inputs(context, node, expected=1)
    x = inputs[0]
    # Promote integer dtypes to fp32 before taking the logarithm;
    # without this, mb.log rejects int inputs.
    if types.is_int(x.dtype):
        x = mb.cast(x=x, dtype="fp32")
    context.add(mb.log(x=x, name=node.name))


@register_torch_op(torch_alias=["round"])
Expand Down
27 changes: 27 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5546,6 +5546,33 @@ def test_elementwise_numerically_stable(
rand_range=(20, 100),
)

@pytest.mark.parametrize(
    "compute_unit, backend, dtype",
    itertools.product(
        compute_units,
        backends,
        [np.int32, np.float32],
    ),
)
def test_log_dtype(self, compute_unit, backend, dtype):
    """torch.log must convert for integer as well as float inputs."""
    shape = (2, 3)
    # Draw values in [1, 100) so log is defined for every element.
    data = torch.from_numpy(np.random.randint(1, 100, shape).astype(dtype))

    self.run_compare_torch(
        data,
        ModuleWrapper(torch.log),
        backend=backend,
        compute_unit=compute_unit,
        input_as_shape=False,
        converter_input_type=[TensorType(shape=shape, dtype=dtype)],
    )


class TestAtan2(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down

0 comments on commit b2f7190

Please sign in to comment.