From 043bae18ad12d1f7120a01a016f36249b7d87553 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Sep 2024 12:37:14 -0400 Subject: [PATCH] fix: pretty printing of MaxPool Layer (#891) --- Project.toml | 2 +- src/layers/display.jl | 4 +++- src/layers/pooling.jl | 12 +++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 537c8574ea..ad668eecf4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.0.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/layers/display.jl b/src/layers/display.jl index 6f5f52c644..c272083cf4 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -16,6 +16,8 @@ end show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: show_leaflike(x::AbstractLuxLayer) = false +isa_printable_leaf(x) = false + function underscorise(n::Integer) return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') end @@ -27,7 +29,7 @@ function big_show(io::IO, obj, indent::Int=0, name=nothing) return end children = printable_children(obj) - if all(show_leaflike, values(children)) + if all(show_leaflike, values(children)) || isa_printable_leaf(obj) layer_show(io, obj, indent, name) else println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), "(") diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index bc4da7b089..f29bc8db41 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -196,7 +196,7 @@ for layer_op in (:Max, :Mean, :LP) window; stride, pad, dilation, p)) end - function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name)) + function Base.show(io::IO, m::$(layer_name)) kernel_size = m.layer.mode.kernel_size print(io, string($(Meta.quot(layer_name))), "($(kernel_size)") pad = m.layer.mode.pad @@ -213,6 +213,8 @@ for layer_op in (:Max, :Mean, :LP) print(io, ")") end + PrettyPrinting.isa_printable_leaf(::$(layer_name)) = true + # Global Pooling Layer @doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <: AbstractLuxWrapperLayer{:layer} @@ -223,7 +225,7 @@ for layer_op in (:Max, :Mean, :LP) return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) end - function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name)) + function Base.show(io::IO, g::$(global_layer_name)) print(io, string($(Meta.quot(global_layer_name))), "(") if $(Meta.quot(op)) == :lp g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p) @@ -231,6 +233,8 @@ for layer_op in (:Max, :Mean, :LP) print(io, ")") end + PrettyPrinting.isa_printable_leaf(::$(global_layer_name)) = true + # Adaptive Pooling Layer @doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <: AbstractLuxWrapperLayer{:layer} @@ -242,12 +246,14 @@ for layer_op in (:Max, :Mean, :LP) static(:adaptive), $(Meta.quot(op)), out_size; p)) end - function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name)) + function Base.show(io::IO, a::$(adaptive_layer_name)) print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size) if $(Meta.quot(op)) == :lp a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) end print(io, ")") end + + PrettyPrinting.isa_printable_leaf(::$(adaptive_layer_name)) = true end end