diff --git a/Project.toml b/Project.toml index f50c957b4c..5b8252808c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.3.2" +version = "1.3.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index 43b4b009e5..4626e48fd1 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -205,7 +205,7 @@ used inside a `Chain`. account for the total number of parameters printed at the bottom. """ macro compact(_exs...) - return CompactMacroImpl.compact_macro_impl(_exs...) + return CompactMacroImpl.compact_macro_impl(__source__, __module__, _exs...) end """ @@ -434,7 +434,7 @@ using LuxCore: LuxCore, AbstractLuxLayer using ..Lux: Lux, CompactLuxLayer, LuxCompactModelParsingException, StatefulLuxLayer, safe_getproperty -function compact_macro_impl(_exs...) +function compact_macro_impl(__source__, __module__, _exs...) # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs if isempty(_exs) msg = "expects at least two expressions: a function and at least one keyword" @@ -499,15 +499,16 @@ function compact_macro_impl(_exs...) # edit expressions vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) - fex = supportself(fex, vars, splatted_kwargs) + fex = supportself(fex, vars, splatted_kwargs, __source__) # assemble return esc(:($CompactLuxLayer( - $(static)($(dispatch)), $fex, $name, ($layer, $input, $block), - (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) + $(static)($(dispatch)), $(fex), $(name), ($layer, $input, $block), + (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)) + )) end -function supportself(fex::Expr, vars, splatted_kwargs) +function supportself(fex::Expr, vars, splatted_kwargs, __source__) @gensym self ps st curried_f res # To avoid having to manipulate fex's arguments and body explicitly, we split the input # function body and add the required arguments to the function definition. @@ -562,6 +563,11 @@ function supportself(fex::Expr, vars, splatted_kwargs) else modified_body = flattened_expr end + + modified_body = MacroTools.to_line( + __source__, MacroTools.to_flag(modified_body) + ) + sdef[:body] = Expr(:let, Expr(:block, calls...), modified_body) sdef[:args] = args return combinedef(sdef)