Skip to content

Commit

Permalink
Merge pull request #11 from pedromxavier/px/upgrade-images
Browse files Browse the repository at this point in the history
Upgrade Mental Images Interface
  • Loading branch information
pedromxavier authored Apr 6, 2022
2 parents 58f96ce + 8139ef9 commit 4450fc3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 30 deletions.
32 changes: 10 additions & 22 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ test_x = Array{Bool}[test_x[:, :, i] .>= 0.25 for i = 1:10_000]
train!(wnn, train_y, train_x)
s = count(classify(wnn, test_x; bleach=12) .== test_y)
s = count(classify(wnn, test_x; bleach=10) .== test_y)
println("Accuracy: $(100. * s / 10_000)%")
```
Expand All @@ -38,23 +38,17 @@ println("Accuracy: $(100. * s / 10_000)%")
```@example mnist
using Images
images = WiSARD.images(wnn)
images = WiSARD.images(RGB, wnn; w=28, h=28)
img = images[0]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[0]
```

```@example mnist
img = images[1]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[1]
```

```@example mnist
img = images[2]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[2]
```

## Fashion MNIST
Expand Down Expand Up @@ -82,7 +76,7 @@ test_x = Array{Bool}[test_x[:, :, i] .>= 0.25 for i = 1:10_000]
train!(wnn, train_y, train_x)
s = count(classify(wnn, test_x; bleach=12) .== test_y)
s = count(classify(wnn, test_x; bleach=10) .== test_y)
println("Accuracy: $(100. * s / 10_000)%")
```
Expand All @@ -91,21 +85,15 @@ println("Accuracy: $(100. * s / 10_000)%")
```@example fashion-mnist
using Images
images = WiSARD.images(wnn)
images = WiSARD.images(RGB, wnn; w=28, h=28)
img = images[0]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[0]
```

```@example fashion-mnist
img = images[1]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[1]
```

```@example fashion-mnist
img = images[2]
RGB.([img[(i - 1) * 28 + j] for i = 1:28, j = 1:28])
images[2]
```
46 changes: 41 additions & 5 deletions src/images/images.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@ function address(wnn::WNN{S, T}, k::T) where {S, T}
[j for j = 1:wnn.d if (k >> (j - 1)) % 2 == 1]
end

function images(wnn::WNN{S, T}, y::S) where {S, T}
function images(
N::Type{<:Any},
wnn::WNN{S, T},
y::S;
w::Union{Int, Nothing} = nothing,
h::Union{Int, Nothing} = nothing,
) where {S, T}
w, h = if isnothing(w) && isnothing(h)
1, wnn.n * wnn.d
elseif isnothing(w) || isnothing(h)
error("Both width 'w' and height 'h' must be informed, or none of them")
else
w, h
end

cls = wnn.cls[y]
img = zeros(Int, wnn.n * wnn.d)

Expand All @@ -17,12 +31,34 @@ function images(wnn::WNN{S, T}, y::S) where {S, T}
end

return if M == 0
Float64.(img)
N[img[(i - 1) * w + j] for i=1:h, j=1:w]
else
Float64.(img ./ M)
N[img[(i - 1) * w + j] / M for i=1:h, j=1:w]
end
end

function images(wnn::WNN{S, T}) where {S, T}
return Dict{S, Array{Float64}}(y => images(wnn, y) for y in keys(wnn.cls))
function images(
wnn::WNN{S, T},
y::S;
w::Union{Int, Nothing} = nothing,
h::Union{Int, Nothing} = nothing,
) where{S, T}
images(Float64, wnn, y; w=w, h=h)
end

function images(
N::Type{<:Any},
wnn::WNN{S, T};
w::Union{Int, Nothing} = nothing,
h::Union{Int, Nothing} = nothing,
) where {S, T}
Dict{S, Array{N}}(y => images(N, wnn, y; w=w, h=h) for y in keys(wnn.cls))
end

function images(
wnn::WNN;
w::Union{Int, Nothing} = nothing,
h::Union{Int, Nothing} = nothing,
)
images(Float64, wnn; w=w, h=h)
end
3 changes: 0 additions & 3 deletions src/model/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ end
Base.show(io::IO, wnn::WNN{S, T}) where {S <: Any, T <: BigInt} = print(io, "WNN[∞ bits, $(wnn.d) × $(wnn.n)]")
Base.show(io::IO, wnn::WNN{S, T}) where {S <: Any, T <: Unsigned} = print(io, "WNN[$(T.size * 8) bits, $(wnn.d) × $(wnn.n)]")

@doc raw"""
address(wnn::WNN, x::AbstractArray, i::Int)
"""
function address(wnn::WNN, x::AbstractArray, i::Int)
@inbounds sum(iszero(x[wnn.map[(i - 1) * wnn.d + j]]) ? 0 : 1 << (j - 1) for j = 1:wnn.d)
end
Expand Down

0 comments on commit 4450fc3

Please sign in to comment.