Skip to content

Commit

Permalink
implement prod resolution and terminal prod argument
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 12, 2023
1 parent 4e15191 commit 49e8652
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 49 deletions.
87 changes: 86 additions & 1 deletion src/prod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ export prod,
PreserveTypeRightProd,
GenericProd,
ProductOf,
LinearizedProductOf
LinearizedProductOf,
TerminalProdArgument,
resolve_prod_strategy

"""
UnspecifiedProd
Expand Down Expand Up @@ -429,3 +431,86 @@ function Base.prod(
) where {L,R}
return ProductOf(push!(getleft(left), right), getright(left))
end

"""
TerminalProdArgument(argument)
`TerminalProdArgument` is a specialized wrapper structure. When used as an argument to the `prod` function, it returns itself without considering any product strategy
and does not perform any safety checks (e.g. `variate_form` or `support`). Attempting to calculate the product of two instances of `TerminalProdArgument` will raise an error.
Use `.argument` field to get the underlying wrapped argument.
"""
struct TerminalProdArgument{T}
argument::T
end

function Base.show(io::IO, prod::TerminalProdArgument)
return print(io, "TerminalProdArgument(", prod.argument, ")")
end
Base.convert(::Type{TerminalProdArgument}, something) = TerminalProdArgument(something)
Base.convert(::Type{TerminalProdArgument}, terminal::TerminalProdArgument) = terminal

BayesBase.paramfloattype(terminal::TerminalProdArgument) = paramfloattype(terminal.argument)
function BayesBase.convert_paramfloattype(
::Type{T}, terminal::TerminalProdArgument
) where {T}
return TerminalProdArgument(convert_paramfloattype(T, terminal.argument))
end

function default_prod_rule(::Type{<:TerminalProdArgument}, ::Type{T}) where {T}
return PreserveTypeProd(TerminalProdArgument)
end
function default_prod_rule(::Type{T}, ::Type{<:TerminalProdArgument}) where {T}
return PreserveTypeProd(TerminalProdArgument)
end
function default_prod_rule(::Type{<:TerminalProdArgument}, ::Type{<:TerminalProdArgument})
return PreserveTypeProd(TerminalProdArgument)
end

function Base.prod(
::PreserveTypeProd{TerminalProdArgument}, left::TerminalProdArgument, right
)
return left
end
function Base.prod(
::PreserveTypeProd{TerminalProdArgument}, left, right::TerminalProdArgument
)
return right
end
function Base.prod(
::PreserveTypeProd{TerminalProdArgument},
left::TerminalProdArgument,
right::TerminalProdArgument,
)
return error("Invalid product: `$(left)` × `$(right)`")
end

"""
resolve_prod_strategy(left, right)
Given two strategies, this function returns the one with higher priority, if possible.
"""
function resolve_prod_strategy(left, right)
return error(
"Cannot resolve strategies `$(left)` and $(right). Strategies have the same priority.",
)
end

function resolve_prod_strategy(left::T, right::T) where {T}
if (left === right)
return left
else
error(
"Cannot resolve strategies `$(left)` and $(right). Strategies have the same priority.",
)
end
end

resolve_prod_strategy(any, ::GenericProd) = any
resolve_prod_strategy(::GenericProd, any) = any
resolve_prod_strategy(::GenericProd, ::GenericProd) = GenericProd()
resolve_prod_strategy(::Nothing, ::GenericProd) = GenericProd()
resolve_prod_strategy(::GenericProd, ::Nothing) = GenericProd()

resolve_prod_strategy(::Nothing, any) = any
resolve_prod_strategy(any, ::Nothing) = any
resolve_prod_strategy(::Nothing, ::Nothing) = nothing
200 changes: 152 additions & 48 deletions test/prod_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,62 +24,112 @@ end
end

@testset "`ClosedProd` for distribution objects should assume `ProdPreserveType(Distribution)`" begin
@test prod(ClosedProd(), ADistributionObject(), ADistributionObject()) isa ADistributionObject
@test prod(ClosedProd(), ADistributionObject(), ADistributionObject()) isa
ADistributionObject
end

end

@testitem "PreserveTypeProd" begin
include("./prod_setuptests.jl")

@testset "`missing` should be ignored with the `PreserveTypeProd`" begin
# Can convert the result of the prod to the desired type
@test prod(PreserveTypeProd(SomeUnknownObject), missing, SomeUnknownObject()) isa SomeUnknownObject
@test prod(PreserveTypeProd(SomeUnknownObject), SomeUnknownObject(), missing) isa SomeUnknownObject
@test prod(PreserveTypeProd(SomeUnknownObject), missing, SomeUnknownObject()) isa
SomeUnknownObject
@test prod(PreserveTypeProd(SomeUnknownObject), SomeUnknownObject(), missing) isa
SomeUnknownObject
@test prod(PreserveTypeProd(Missing), missing, missing) isa Missing
@test prod(PreserveTypeProd(SomeUnknownObject), missing, missing) isa Missing
end

@testset "`PreserveTypeLeftProd` should preserve the type of the left argument" begin
@test prod(PreserveTypeLeftProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa ObjectWithClosedProd1
@test prod(PreserveTypeLeftProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa ObjectWithClosedProd2
@test prod(
PreserveTypeLeftProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()
) isa ObjectWithClosedProd1
@test prod(
PreserveTypeLeftProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()
) isa ObjectWithClosedProd2
end

@testset "`PreserveTypeRightProd` should preserve the type of the left argument" begin
@test prod(PreserveTypeRightProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa ObjectWithClosedProd2
@test prod(PreserveTypeRightProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa ObjectWithClosedProd1
@test prod(
PreserveTypeRightProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()
) isa ObjectWithClosedProd2
@test prod(
PreserveTypeRightProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()
) isa ObjectWithClosedProd1
end

@testset "`ProdPreserveType(T)` should preserve the desired type of `T`" begin
@test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa
ObjectWithClosedProd1
@test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa
ObjectWithClosedProd1
@test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa
ObjectWithClosedProd1
@test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa
ObjectWithClosedProd1

@test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa
ObjectWithClosedProd2
@test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa
ObjectWithClosedProd2
@test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa
ObjectWithClosedProd2
@test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa
ObjectWithClosedProd2
@test prod(
PreserveTypeProd(ObjectWithClosedProd1),
ObjectWithClosedProd1(),
ObjectWithClosedProd1(),
) isa ObjectWithClosedProd1
@test prod(
PreserveTypeProd(ObjectWithClosedProd1),
ObjectWithClosedProd1(),
ObjectWithClosedProd2(),
) isa ObjectWithClosedProd1
@test prod(
PreserveTypeProd(ObjectWithClosedProd1),
ObjectWithClosedProd2(),
ObjectWithClosedProd1(),
) isa ObjectWithClosedProd1
@test prod(
PreserveTypeProd(ObjectWithClosedProd1),
ObjectWithClosedProd2(),
ObjectWithClosedProd2(),
) isa ObjectWithClosedProd1

@test prod(
PreserveTypeProd(ObjectWithClosedProd2),
ObjectWithClosedProd1(),
ObjectWithClosedProd1(),
) isa ObjectWithClosedProd2
@test prod(
PreserveTypeProd(ObjectWithClosedProd2),
ObjectWithClosedProd1(),
ObjectWithClosedProd2(),
) isa ObjectWithClosedProd2
@test prod(
PreserveTypeProd(ObjectWithClosedProd2),
ObjectWithClosedProd2(),
ObjectWithClosedProd1(),
) isa ObjectWithClosedProd2
@test prod(
PreserveTypeProd(ObjectWithClosedProd2),
ObjectWithClosedProd2(),
ObjectWithClosedProd2(),
) isa ObjectWithClosedProd2

# The output can be converted to an `Int` (see the fixtures above)
@test prod(PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa Int
@test prod(PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa Int
@test prod(PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa Int
@test prod(PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa Int
@test prod(
PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd1()
) isa Int
@test prod(
PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd2()
) isa Int
@test prod(
PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd1()
) isa Int
@test prod(
PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd2()
) isa Int

# The output can not be converted to a `Float` (see the fixtures above)
@test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd1())
@test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd2())
@test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd1())
@test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd2())
@test_throws MethodError prod(
PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd1()
)
@test_throws MethodError prod(
PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd2()
)
@test_throws MethodError prod(
PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd1()
)
@test_throws MethodError prod(
PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd2()
)
end
end

Expand All @@ -91,14 +141,20 @@ end
@testset "GenericProd should use `default_prod_rule` where possible" begin

# `SomeUnknownObject` does not implement any prod rule (see the fixtures above)
@test SomeUnknownObject() × SomeUnknownObject() isa ProductOf{SomeUnknownObject, SomeUnknownObject}
@test ObjectWithClosedProd1() × SomeUnknownObject() isa ProductOf{ObjectWithClosedProd1, SomeUnknownObject}
@test SomeUnknownObject() × ObjectWithClosedProd1() isa ProductOf{SomeUnknownObject, ObjectWithClosedProd1}

@test getleft(ObjectWithClosedProd1() × SomeUnknownObject()) === ObjectWithClosedProd1()
@test getright(ObjectWithClosedProd1() × SomeUnknownObject()) === SomeUnknownObject()
@test SomeUnknownObject() × SomeUnknownObject() isa
ProductOf{SomeUnknownObject,SomeUnknownObject}
@test ObjectWithClosedProd1() × SomeUnknownObject() isa
ProductOf{ObjectWithClosedProd1,SomeUnknownObject}
@test SomeUnknownObject() × ObjectWithClosedProd1() isa
ProductOf{SomeUnknownObject,ObjectWithClosedProd1}

@test getleft(ObjectWithClosedProd1() × SomeUnknownObject()) ===
ObjectWithClosedProd1()
@test getright(ObjectWithClosedProd1() × SomeUnknownObject()) ===
SomeUnknownObject()
@test getleft(SomeUnknownObject() × ObjectWithClosedProd1()) === SomeUnknownObject()
@test getright(SomeUnknownObject() × ObjectWithClosedProd1()) === ObjectWithClosedProd1()
@test getright(SomeUnknownObject() × ObjectWithClosedProd1()) ===
ObjectWithClosedProd1()

# Both `ObjectWithClosedProd1` and `ObjectWithClosedProd2` implement `ClosedProd` as a default (see the fixtures above)
@test ObjectWithClosedProd1() × ObjectWithClosedProd1() isa ObjectWithClosedProd1
Expand Down Expand Up @@ -134,37 +190,85 @@ end
@testset "ProdGeneric should create a product tree if closed form product is not available" begin
d1 = SomeUnknownObject()

@test 1.0 × 1 × d1 isa ProductOf{ProductOf{Float64, Int}, SomeUnknownObject}
@test 1 × 1.0 × d1 isa ProductOf{ProductOf{Int, Float64}, SomeUnknownObject}
@test 1.0 × 1 × d1 isa ProductOf{ProductOf{Float64,Int},SomeUnknownObject}
@test 1 × 1.0 × d1 isa ProductOf{ProductOf{Int,Float64},SomeUnknownObject}
end

@testset "ProdGeneric should create a linearised product tree if closed form product is not available, but objects are of the same type" begin
d1 = SomeUnknownObject()
d2 = ObjectWithClosedProd1()

@test d1 × d1 isa ProductOf{SomeUnknownObject, SomeUnknownObject}
@test d1 × d1 isa ProductOf{SomeUnknownObject,SomeUnknownObject}

@testset let product = d1 × d1 × d1
@test product isa LinearizedProductOf{SomeUnknownObject}
@test length(product) === 3

# Test that the next prod rule should preserve the type of the linearized product
@test default_prod_rule(product, d1) isa PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}}
@test default_prod_rule(product, d1) isa
PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}}
end

@testset let product = (d1 × d1 × d1) × d1
@test product isa LinearizedProductOf{SomeUnknownObject}
@test length(product) === 4

# Test that the next prod rule should preserve the type of the linearized product
@test default_prod_rule(product, d1) isa PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}}
@test default_prod_rule(product, d1) isa
PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}}
end

@test d2 × d1 × d1 × d1 isa ProductOf{ObjectWithClosedProd1, LinearizedProductOf{SomeUnknownObject}}
@test d2 × d1 × d1 × d1 × d1 isa ProductOf{ObjectWithClosedProd1, LinearizedProductOf{SomeUnknownObject}}
@test d2 × d1 × d1 × d1 isa
ProductOf{ObjectWithClosedProd1,LinearizedProductOf{SomeUnknownObject}}
@test d2 × d1 × d1 × d1 × d1 isa
ProductOf{ObjectWithClosedProd1,LinearizedProductOf{SomeUnknownObject}}

# d2 × (...) × d2 should fold if closed prod is available
@test d2 × d1 × d1 × d1 × d1 × d2 == (d2 × d2) × d1 × d1 × d1 × d1
@test d2 × d1 × d2 × d1 × d2 × d1 × d1 × d2 == (d2 × d2 × d2 × d2) × d1 × d1 × d1 × d1
@test d2 × d1 × d2 × d1 × d2 × d1 × d1 × d2 ==
(d2 × d2 × d2 × d2) × d1 × d1 × d1 × d1
end
end

@testitem "TerminalProdArgument" begin
include("./prod_setuptests.jl")

d1 = SomeUnknownObject()
d2 = ObjectWithClosedProd1()

@test prod(GenericProd(), TerminalProdArgument(d1), d2) === TerminalProdArgument(d1)
@test prod(GenericProd(), d2, TerminalProdArgument(d1)) === TerminalProdArgument(d1)
@test prod(GenericProd(), TerminalProdArgument(d2), d1) === TerminalProdArgument(d2)
@test prod(GenericProd(), d1, TerminalProdArgument(d2)) === TerminalProdArgument(d2)
@test_throws ErrorException prod(
GenericProd(), TerminalProdArgument(d1), TerminalProdArgument(d2)
)
@test_throws ErrorException prod(
GenericProd(), TerminalProdArgument(d2), TerminalProdArgument(d1)
)
end

@testitem "resolve_prod_strategy" begin
for strategy in (
ClosedProd(),
PreserveTypeProd(Int),
PreserveTypeLeftProd(),
PreserveTypeRightProd(),
GenericProd(),
)
@test resolve_prod_strategy(strategy, strategy) === strategy
@test resolve_prod_strategy(strategy, GenericProd()) === strategy
@test resolve_prod_strategy(GenericProd(), strategy) === strategy
@test resolve_prod_strategy(nothing, strategy) === strategy
@test resolve_prod_strategy(strategy, nothing) === strategy
end

@test resolve_prod_strategy(nothing, nothing) === nothing

@test_throws ErrorException resolve_prod_strategy(
PreserveTypeLeftProd(), PreserveTypeRightProd()
)
@test_throws ErrorException resolve_prod_strategy(
PreserveTypeRightProd(), PreserveTypeLeftProd()
)
end

0 comments on commit 49e8652

Please sign in to comment.