diff --git a/docs/src/examples.md b/docs/src/examples.md index b9ba895..59796b0 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -29,7 +29,7 @@ test_x = Array{Bool}[test_x[:, :, i] .>= 0.25 for i = 1:10_000] train!(wnn, train_y, train_x) -s = count(classify(wnn, test_x; bleach=12) .== test_y) +s = count(classify(wnn, test_x; bleach=10) .== test_y) println("Accuracy: $(100. * s / 10_000)%") ``` @@ -38,23 +38,17 @@ println("Accuracy: $(100. * s / 10_000)%") ```@example mnist using Images -images = WiSARD.images(wnn) +images = WiSARD.images(RGB, wnn; w=28, h=28) -img = images[0] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[0] ``` ```@example mnist -img = images[1] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[1] ``` ```@example mnist -img = images[2] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[2] ``` ## Fashion MNIST @@ -82,7 +76,7 @@ test_x = Array{Bool}[test_x[:, :, i] .>= 0.25 for i = 1:10_000] train!(wnn, train_y, train_x) -s = count(classify(wnn, test_x; bleach=12) .== test_y) +s = count(classify(wnn, test_x; bleach=10) .== test_y) println("Accuracy: $(100. * s / 10_000)%") ``` @@ -91,21 +85,15 @@ println("Accuracy: $(100. * s / 10_000)%") ```@example fashion-mnist using Images -images = WiSARD.images(wnn) +images = WiSARD.images(RGB, wnn; w=28, h=28) -img = images[0] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[0] ``` ```@example fashion-mnist -img = images[1] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[1] ``` ```@example fashion-mnist -img = images[2] - -RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28]) +images[2] ``` \ No newline at end of file diff --git a/src/images/images.jl b/src/images/images.jl index c28f7a1..2390849 100644 --- a/src/images/images.jl +++ b/src/images/images.jl @@ -2,7 +2,21 @@ function address(wnn::WNN{S, T}, k::T) where {S, T} [j for j = 1:wnn.d if (k >> (j - 1)) % 2 == 1] end -function images(wnn::WNN{S, T}, y::S) where {S, T} +function images( + N::Type{<:Any}, + wnn::WNN{S, T}, + y::S; + w::Union{Int, Nothing} = nothing, + h::Union{Int, Nothing} = nothing, + ) where {S, T} + w, h = if isnothing(w) && isnothing(h) + 1, wnn.n * wnn.d + elseif isnothing(w) || isnothing(h) + error("Both width 'w' and height 'h' must be informed, or none of them") + else + w, h + end + cls = wnn.cls[y] img = zeros(Int, wnn.n * wnn.d) @@ -17,12 +31,34 @@ function images(wnn::WNN{S, T}, y::S) where {S, T} end return if M == 0 - Float64.(img) + N[img[(i - 1) * w + j] for i=1:h, j=1:w] else - Float64.(img ./ M) + N[img[(i - 1) * w + j] / M for i=1:h, j=1:w] end end -function images(wnn::WNN{S, T}) where {S, T} - return Dict{S, Array{Float64}}(y => images(wnn, y) for y in keys(wnn.cls)) +function images( + wnn::WNN{S, T}, + y::S; + w::Union{Int, Nothing} = nothing, + h::Union{Int, Nothing} = nothing, + ) where{S, T} + images(Float64, wnn, y; w=w, h=h) +end + +function images( + N::Type{<:Any}, + wnn::WNN{S, T}; + w::Union{Int, Nothing} = nothing, + h::Union{Int, Nothing} = nothing, + ) where {S, T} + Dict{S, Array{N}}(y => images(N, wnn, y; w=w, h=h) for y in keys(wnn.cls)) +end + +function images( + wnn::WNN; + w::Union{Int, Nothing} = nothing, + h::Union{Int, Nothing} = nothing, + ) + images(Float64, wnn; w=w, h=h) end \ No newline at end of file diff --git a/src/model/model.jl b/src/model/model.jl index a052d3f..caef95a 100644 --- a/src/model/model.jl +++ b/src/model/model.jl @@ -45,9 +45,6 @@ end Base.show(io::IO, wnn::WNN{S, T}) where {S <: Any, T <: BigInt} = print(io, "WNN[∞ bits, $(wnn.d) × $(wnn.n)]") Base.show(io::IO, wnn::WNN{S, T}) where {S <: Any, T <: Unsigned} = print(io, "WNN[$(T.size * 8) bits, $(wnn.d) × $(wnn.n)]") -@doc raw""" - address(wnn::WNN, x::AbstractArray, i::Int) -""" function address(wnn::WNN, x::AbstractArray, i::Int) @inbounds sum(iszero(x[wnn.map[(i - 1) * wnn.d + j]]) ? 0 : 1 << (j - 1) for j = 1:wnn.d) end