Skip to content

Commit

Permalink
fix: correct usage of Ops.select for Base.ifelse (#332)
Browse files Browse the repository at this point in the history
* fix: correct usage of Ops.select for Base.ifelse

* Update src/TracedRNumber.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] authored Dec 6, 2024
1 parent b7de1e6 commit 1bb0000
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,19 @@ function Base.ifelse(
@nospecialize(x::TracedRNumber{T1}),
@nospecialize(y::TracedRNumber{T2})
) where {T1,T2}
return TracedRNumber{promote_type(T1, T2)}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
),
)
@warn "`ifelse` with different element-types in Reactant works by promoting the \
element-type to the common type. This is semantically different from the \
behavior of `ifelse` in Base. Use with caution" maxlog = 1
T = promote_type(T1, T2)
return ifelse(pred, promote_to(TracedRNumber{T}, x), promote_to(TracedRNumber{T}, y))
end

function Base.ifelse(
@nospecialize(pred::TracedRNumber{Bool}),
@nospecialize(x::TracedRNumber{T}),
@nospecialize(y::TracedRNumber{T})
) where {T}
return Ops.select(pred, x, y)
end

for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
Expand Down
13 changes: 13 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,16 @@ end

@test @jit(f_row_major(x_ra)) f_row_major(x)
end

@testset "ifelse" begin
@test 1.0 ==
@jit ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
@test @jit(
ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
) isa ConcreteRNumber{Float64}
@test 0.0f0 ==
@jit ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
@test @jit(
ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0))
) isa ConcreteRNumber{Float32}
end

0 comments on commit 1bb0000

Please sign in to comment.