From 2bf2ca837816795f5c6f16fed6e629220008e49d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 1 Jan 2025 16:56:15 -0500 Subject: [PATCH] feat: update ConvMixer to support reactant (#1063) * fix: update to new reactant changes * fix: use enzyme correctly * fix: update training code * feat: handle optimisers correctly * fix: don't transfer via CPU * refactor: remove promote_to handling * revert: specific version bumps * docs: add multiple CIFAR10 examples using Reactant * feat: BF16 training + inference * fix: incorrect AdamW handling * feat: implement resnet20 baseline * fix: conv mixer working again --- Project.toml | 2 +- docs/Project.toml | 2 +- docs/src/.vitepress/config.mts | 4 +- docs/src/tutorials/index.md | 6 +- examples/{ConvMixer => CIFAR10}/Project.toml | 11 +- examples/CIFAR10/README.md | 69 +++++++++ examples/CIFAR10/common.jl | 146 +++++++++++++++++++ examples/CIFAR10/conv_mixer.jl | 50 +++++++ examples/CIFAR10/resnet20.jl | 78 ++++++++++ examples/CIFAR10/simple_cnn.jl | 34 +++++ examples/ConvMixer/README.md | 82 ----------- examples/ConvMixer/main.jl | 111 -------------- src/helpers/optimizers.jl | 50 ++----- 13 files changed, 405 insertions(+), 240 deletions(-) rename examples/{ConvMixer => CIFAR10}/Project.toml (82%) create mode 100644 examples/CIFAR10/README.md create mode 100644 examples/CIFAR10/common.jl create mode 100644 examples/CIFAR10/conv_mixer.jl create mode 100644 examples/CIFAR10/resnet20.jl create mode 100644 examples/CIFAR10/simple_cnn.jl delete mode 100644 examples/ConvMixer/README.md delete mode 100644 examples/ConvMixer/main.jl diff --git a/Project.toml b/Project.toml index a08546d73d..77788a3803 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.4.3" +version = "1.4.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/Project.toml b/docs/Project.toml index 0e561d7625..e0d9b02476 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -56,7 +56,7 @@ Optimisers = "0.4.1" Pkg = "1.10" Printf = "1.10" Random = "1.10" -Reactant = "0.2.12" +Reactant = "0.2.11" StableRNGs = "1" StaticArrays = "1" WeightInitializers = "1" diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index f785f6a316..bdd870e08f 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -243,8 +243,8 @@ export default defineConfig({ link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM", }, { - text: "ConvMixer on CIFAR-10", - link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + text: "Different Vision Models on CIFAR-10", + link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", }, ], }, diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 75c45f7b93..d639f309e1 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -97,10 +97,10 @@ const large_models = [ desc: "Train a Diffusion Model to generate images from Gaussian noises." }, { - href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", src: "https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp", - caption: "ConvMixer on CIFAR-10", - desc: "Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes." 
+      caption: "Vision Models on CIFAR-10",
+      desc: "Train different vision models on CIFAR-10 to 90% accuracy within 10 minutes."
     }
 ];
 
diff --git a/examples/ConvMixer/Project.toml b/examples/CIFAR10/Project.toml
similarity index 82%
rename from examples/ConvMixer/Project.toml
rename to examples/CIFAR10/Project.toml
index 8ae7806576..c0dffde559 100644
--- a/examples/ConvMixer/Project.toml
+++ b/examples/CIFAR10/Project.toml
@@ -1,7 +1,9 @@
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
 Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
@@ -11,18 +13,20 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
+ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+BFloat16s = "0.5.0"
 Comonicon = "1.0.8"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3"
+Enzyme = "0.13.14"
 ImageCore = "0.10.2"
 ImageShow = "0.3.8"
 Interpolations = "0.15.1"
@@ -32,10 +36,9 @@ MLDatasets = "0.7.14"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
-PreferenceTools = "0.1.2"
 Printf = "1.10"
-ProgressBars = "1.5.1"
 Random = "1.10"
+Reactant = "0.2.12"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
diff --git a/examples/CIFAR10/README.md b/examples/CIFAR10/README.md
new file mode 100644
index 0000000000..6eebe74e0e
--- /dev/null
+++ b/examples/CIFAR10/README.md
@@ -0,0 +1,69 @@
+# Train Vision Models on CIFAR-10
+
+✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚
+
+We have the following scripts to train vision models on CIFAR-10:
+
+1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers.
+2. `resnet20.jl`: ResNet-20 model.
+3. `conv_mixer.jl`: ConvMixer model.
+
+To see the options for each script, run it with the `--help` flag.
+
+> [!NOTE]
+> To train a model using Reactant.jl, pass `--backend=reactant` to the script. This is
+> the recommended way to train the models in this directory.
+
+> [!NOTE]
+> Passing `--bfloat16` will train in BFloat16 precision. This requires Julia 1.11 or
+> later.
+
+## Simple CNN
+
+```bash
+julia --startup-file=no \
+    --project=. \
+    --threads=auto \
+    simple_cnn.jl \
+    --backend=reactant
+```
+
+On an RTX 4050 6GB Laptop GPU, training takes approximately 3 minutes, and the final
+training and test accuracies are 97% and 65%, respectively.
+
+## ResNet 20
+
+```bash
+julia --startup-file=no \
+    --project=. \
+    --threads=auto \
+    resnet20.jl \
+    --backend=reactant
+```
+
+On an RTX 3060 GPU, each epoch takes about 4.5 seconds, and the final training and test
+accuracies are 89% and 75%, respectively.
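+
+For example, to run the same script in BFloat16 precision on the Reactant backend
+(requires Julia 1.11 or later), something like the following should work, using the
+`--bfloat16` flag defined by the script:
+
+```bash
+julia --startup-file=no \
+    --project=. \
+    --threads=auto \
+    resnet20.jl \
+    --backend=reactant \
+    --bfloat16
+```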
+
+## ConvMixer
+
+> [!NOTE]
+> This code has been adapted from https://github.com/locuslab/convmixer-cifar10
+
+This is a simple ConvMixer training script for CIFAR-10. It's probably a good starting
+point for new experiments on small datasets.
+
+You can get around **90.0%** accuracy in just **25 epochs** by running the script as shown
+below, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2 by default.
+
+```bash
+julia --startup-file=no \
+    --project=. \
+    --threads=auto \
+    conv_mixer.jl \
+    --backend=reactant
+```
+
+### Notes
+
+  1. To match the results from the original repo, we need additional augmentation
+     strategies that are currently not implemented in DataAugmentation.jl.
diff --git a/examples/CIFAR10/common.jl b/examples/CIFAR10/common.jl
new file mode 100644
index 0000000000..7457306d79
--- /dev/null
+++ b/examples/CIFAR10/common.jl
@@ -0,0 +1,146 @@
+using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
+      Printf, ProgressTables, Random, BFloat16s
+using Reactant, LuxCUDA
+
+@concrete struct TensorDataset
+    dataset
+    transform
+end
+
+Base.length(ds::TensorDataset) = length(ds.dataset)
+
+function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
+    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
+    y = onehotbatch(ds.dataset.targets[idxs], 0:9)
+    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
+end
+
+function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
+    cifar10_mean = (0.4914, 0.4822, 0.4465) .|> T
+    cifar10_std = (0.2471, 0.2435, 0.2616) .|> T
+
+    train_transform = RandomResizeCrop((32, 32)) |>
+                      Maybe(FlipX{2}()) |>
+                      ImageToTensor() |>
+                      Normalize(cifar10_mean, cifar10_std) |>
+                      ToEltype(T)
+
+    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)
+
+    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
+    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)
+
+    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
+    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)
+
+    return trainloader, testloader
+end
+
+function accuracy(model, ps, st, dataloader)
+    total_correct, total = 0, 0
+    cdev = cpu_device()
+    for (x, y) in dataloader
+        target_class = onecold(cdev(y))
+        predicted_class = onecold(cdev(first(model(x, ps, st))))
+        total_correct += sum(target_class .== predicted_class)
+        total += length(target_class)
+    end
+    return total_correct / total
+end
+
+function get_accelerator_device(backend::String)
+    if backend == "gpu_if_available"
+        return gpu_device()
+    elseif backend == "gpu"
+        return gpu_device(; force=true)
+    elseif backend == "reactant"
+        return reactant_device(; force=true)
+    elseif backend == "cpu"
+        return cpu_device()
+    else
+        error("Invalid backend: $(backend). Valid options are: `gpu_if_available`, `gpu`, \
+               `reactant`, and `cpu`.")
+    end
+end
+
+function train_model(
+        model, opt, scheduler=nothing;
+        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
+        bfloat16::Bool=false
+)
+    rng = Random.default_rng()
+    Random.seed!(rng, seed)
+
+    prec = bfloat16 ? bf16 : f32
+    prec_jl = bfloat16 ? BFloat16 : Float32
+    prec_str = bfloat16 ? "BFloat16" : "Float32"
+    @printf "[Info] Using %s precision\n" prec_str
+
+    accelerator_device = get_accelerator_device(backend)
+    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
+    trainloader, testloader = get_cifar10_dataloaders(prec_jl, batchsize; kwargs...) |>
+                              accelerator_device
+
+    ps, st = Lux.setup(rng, model) |> prec |> accelerator_device
+
+    train_state = Training.TrainState(model, ps, st, opt)
+
+    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()
+
+    if backend == "reactant"
+        x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> accelerator_device
+        @printf "[Info] Compiling model with Reactant.jl\n"
+        st_test = Lux.testmode(st)
+        model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
+        @printf "[Info] Model compiled!\n"
+    else
+        model_compiled = model
+    end
+
+    loss_fn = CrossEntropyLoss(; logits=Val(true))
+
+    pt = ProgressTable(;
+        header=[
+            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
+        ],
+        widths=[24, 24, 24, 24, 24],
+        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
+        color=[:normal, :normal, :blue, :blue, :normal],
+        border=true,
+        alignment=[:center, :center, :center, :center, :center]
+    )
+
+    @printf "[Info] Training model\n"
+    initialize(pt)
+
+    for epoch in 1:epochs
+        stime = time()
+        lr = 0
+        for (i, (x, y)) in enumerate(trainloader)
+            if scheduler !== nothing
+                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
+                train_state = Optimisers.adjust!(train_state, lr)
+            end
+            (_, loss, _, train_state) = Training.single_train_step!(
+                adtype, loss_fn, (x, y), train_state
+            )
+            isnan(loss) && error("NaN loss encountered!")
+        end
+        ttime = time() - stime
+
+        train_acc = accuracy(
+            model_compiled, train_state.parameters,
+            Lux.testmode(train_state.states), trainloader
+        ) * 100
+        test_acc = accuracy(
+            model_compiled, train_state.parameters,
+            Lux.testmode(train_state.states), testloader
+        ) * 100
+
+        scheduler === nothing && (lr = NaN32)
+        next(pt, [epoch, lr, train_acc, test_acc, ttime])
+    end
+
+    finalize(pt)
+    @printf "[Info] Finished training\n"
+end
diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl
new file mode 100644
index 0000000000..845baf8ac3
--- /dev/null
+++ b/examples/CIFAR10/conv_mixer.jl
@@ -0,0 +1,50 @@
+using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote
+
+include("common.jl")
+
+function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
+    #! format: off
+    return Chain(
+        Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
+        BatchNorm(dim),
+        [
+            Chain(
+                SkipConnection(
+                    Chain(
+                        Conv(
+                            (kernel_size, kernel_size), dim => dim, gelu;
+                            groups=dim, pad=SamePad()
+                        ),
+                        BatchNorm(dim)
+                    ),
+                    +
+                ),
+                Conv((1, 1), dim => dim, gelu),
+                BatchNorm(dim)
+            )
+            for _ in 1:depth
+        ]...,
+        GlobalMeanPool(),
+        FlattenLayer(),
+        Dense(dim => 10)
+    )
+    #!
format: on +end + +Comonicon.@main function main(; + batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, + patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, + backend::String="reactant", bfloat16::Bool=false +) + model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) + + opt = AdamW(; eta=lr_max, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + lr_schedule = linear_interpolation( + [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] + ) + + return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16) +end diff --git a/examples/CIFAR10/resnet20.jl b/examples/CIFAR10/resnet20.jl new file mode 100644 index 0000000000..c2b9db6dba --- /dev/null +++ b/examples/CIFAR10/resnet20.jl @@ -0,0 +1,78 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote + +include("common.jl") + +function ConvBN(kernel_size, (in_chs, out_chs), act; kwargs...) + return Chain( + Conv(kernel_size, in_chs => out_chs, act; kwargs...), + BatchNorm(out_chs) + ) +end + +function BasicBlock(in_channels, out_channels; stride=1) + connection = if (stride == 1 && in_channels == out_channels) + NoOpLayer() + else + Conv((3, 3), in_channels => out_channels, identity; stride=stride, pad=SamePad()) + end + return Chain( + Parallel( + +, + connection, + Chain( + ConvBN((3, 3), in_channels => out_channels, relu; stride, pad=SamePad()), + ConvBN((3, 3), out_channels => out_channels, identity; pad=SamePad()) + ) + ), + Base.BroadcastFunction(relu) + ) +end + +function ResNet20(; num_classes=10) + layers = [] + + # Initial Conv Layer + push!(layers, Chain( + Conv((3, 3), 3 => 16, relu; pad=SamePad()), + BatchNorm(16) + )) + + # Residual Blocks + block_configs = [ + (16, 16, 3, 1), # (in_channels, out_channels, num_blocks, stride) + (16, 32, 3, 2), + (32, 64, 3, 2) + ] + + for (in_channels, out_channels, num_blocks, stride) in block_configs + for i in 1:num_blocks + push!(layers, + BasicBlock( + i == 1 ? in_channels : out_channels, out_channels; + stride=(i == 1 ? stride : 1) + )) + end + end + + # Global Pooling and Final Dense Layer + push!(layers, GlobalMeanPool()) + push!(layers, FlattenLayer()) + push!(layers, Dense(64 => num_classes)) + + return Chain(layers...) 
+end + +Comonicon.@main function main(; + batchsize::Int=512, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=100, lr::Float64=0.001, + backend::String="reactant", bfloat16::Bool=false +) + model = ResNet20() + + opt = AdamW(; eta=lr, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + lr_schedule = nothing + + return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16) +end diff --git a/examples/CIFAR10/simple_cnn.jl b/examples/CIFAR10/simple_cnn.jl new file mode 100644 index 0000000000..9eed26f192 --- /dev/null +++ b/examples/CIFAR10/simple_cnn.jl @@ -0,0 +1,34 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +include("common.jl") + +function SimpleCNN() + return Chain( + Conv((3, 3), 3 => 16, gelu; stride=2, pad=1), + BatchNorm(16), + Conv((3, 3), 16 => 32, gelu; stride=2, pad=1), + BatchNorm(32), + Conv((3, 3), 32 => 64, gelu; stride=2, pad=1), + BatchNorm(64), + Conv((3, 3), 64 => 128, gelu; stride=2, pad=1), + BatchNorm(128), + GlobalMeanPool(), + FlattenLayer(), + Dense(128 => 64, gelu), + BatchNorm(64), + Dense(64 => 10) + ) +end + +Comonicon.@main function main(; + batchsize::Int=512, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003, + backend::String="reactant", bfloat16::Bool=false +) + model = SimpleCNN() + + opt = AdamW(; eta=lr, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + return train_model(model, opt, nothing; backend, batchsize, seed, epochs, bfloat16) +end diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md deleted file mode 100644 index f072c10748..0000000000 --- a/examples/ConvMixer/README.md +++ /dev/null @@ -1,82 +0,0 @@ -# Train ConvMixer on CIFAR-10 - - ✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚 - -> [!NOTE] -> This code has been adapted from https://github.com/locuslab/convmixer-cifar10 - -This is a simple ConvMixer training script for CIFAR-10. It's probably a good starting point -for new experiments on small datasets. - -You can get around **90.0%** accuracy in just **25 epochs** by running the script with the -following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2. - -```bash -julia --startup-file=no \ - --project=. 
\ - --threads=auto \ - main.jl \ - --lr-max=0.05 \ - --weight-decay=0.0001 -``` - -Here's an example of the output of the above command (on a V100 32GB GPU): - -``` -Epoch 1: Learning Rate 5.05e-03, Train Acc: 56.91%, Test Acc: 56.49%, Time: 129.84 -Epoch 2: Learning Rate 1.01e-02, Train Acc: 69.75%, Test Acc: 68.40%, Time: 21.22 -Epoch 3: Learning Rate 1.51e-02, Train Acc: 76.86%, Test Acc: 74.73%, Time: 21.33 -Epoch 4: Learning Rate 2.01e-02, Train Acc: 81.03%, Test Acc: 78.14%, Time: 21.40 -Epoch 5: Learning Rate 2.51e-02, Train Acc: 72.71%, Test Acc: 70.29%, Time: 21.34 -Epoch 6: Learning Rate 3.01e-02, Train Acc: 83.12%, Test Acc: 80.20%, Time: 21.38 -Epoch 7: Learning Rate 3.51e-02, Train Acc: 82.38%, Test Acc: 78.66%, Time: 21.39 -Epoch 8: Learning Rate 4.01e-02, Train Acc: 84.24%, Test Acc: 79.97%, Time: 21.49 -Epoch 9: Learning Rate 4.51e-02, Train Acc: 84.93%, Test Acc: 80.18%, Time: 21.40 -Epoch 10: Learning Rate 5.00e-02, Train Acc: 84.97%, Test Acc: 80.26%, Time: 21.37 -Epoch 11: Learning Rate 4.52e-02, Train Acc: 89.09%, Test Acc: 83.53%, Time: 21.31 -Epoch 12: Learning Rate 4.05e-02, Train Acc: 91.62%, Test Acc: 85.10%, Time: 21.39 -Epoch 13: Learning Rate 3.57e-02, Train Acc: 93.71%, Test Acc: 86.78%, Time: 21.29 -Epoch 14: Learning Rate 3.10e-02, Train Acc: 95.14%, Test Acc: 87.23%, Time: 21.37 -Epoch 15: Learning Rate 2.62e-02, Train Acc: 95.36%, Test Acc: 87.08%, Time: 21.34 -Epoch 16: Learning Rate 2.15e-02, Train Acc: 97.07%, Test Acc: 87.91%, Time: 21.26 -Epoch 17: Learning Rate 1.67e-02, Train Acc: 98.67%, Test Acc: 89.57%, Time: 21.40 -Epoch 18: Learning Rate 1.20e-02, Train Acc: 99.41%, Test Acc: 89.77%, Time: 21.28 -Epoch 19: Learning Rate 7.20e-03, Train Acc: 99.81%, Test Acc: 90.31%, Time: 21.39 -Epoch 20: Learning Rate 2.50e-03, Train Acc: 99.94%, Test Acc: 90.83%, Time: 21.44 -Epoch 21: Learning Rate 2.08e-03, Train Acc: 99.96%, Test Acc: 90.83%, Time: 21.23 -Epoch 22: Learning Rate 1.66e-03, Train Acc: 99.97%, Test Acc: 90.91%, Time: 21.29 -Epoch 23: Learning Rate 1.25e-03, Train Acc: 99.99%, Test Acc: 90.82%, Time: 21.29 -Epoch 24: Learning Rate 8.29e-04, Train Acc: 99.99%, Test Acc: 90.79%, Time: 21.32 -Epoch 25: Learning Rate 4.12e-04, Train Acc: 100.00%, Test Acc: 90.83%, Time: 21.32 -``` - -## Usage - -```bash - main [options] [flags] - -Options - - --batchsize <512::Int> - --hidden-dim <256::Int> - --depth <8::Int> - --patch-size <2::Int> - --kernel-size <5::Int> - --weight-decay <0.01::Float64> - --seed <42::Int> - --epochs <25::Int> - --lr-max <0.01::Float64> - -Flags - --clip-norm - - -h, --help Print this help message. - --version Print version. -``` - -## Notes - - 1. To match the results from the original repo, we need more augmentation strategies, that - are currently not implemented in DataAugmentation.jl. - 2. Don't compare the reported timings in that repo against the numbers here. They time the - entire loop. We only time the training part of the loop. 
diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl deleted file mode 100644 index 03ddc63a52..0000000000 --- a/examples/ConvMixer/main.jl +++ /dev/null @@ -1,111 +0,0 @@ -using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA, - MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, - StableRNGs, Statistics, Zygote - -CUDA.allowscalar(false) - -@concrete struct TensorDataset - dataset - transform -end - -Base.length(ds::TensorDataset) = length(ds.dataset) - -function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) - img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) - y = onehotbatch(ds.dataset.targets[idxs], 0:9) - return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y -end - -function get_dataloaders(batchsize) - cifar10_mean = (0.4914, 0.4822, 0.4465) - cifar10_std = (0.2471, 0.2435, 0.2616) - - train_transform = RandomResizeCrop((32, 32)) |> - Maybe(FlipX{2}()) |> - ImageToTensor() |> - Normalize(cifar10_mean, cifar10_std) - - test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) - - trainset = TensorDataset(CIFAR10(:train), train_transform) - trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true) - - testset = TensorDataset(CIFAR10(:test), test_transform) - testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true) - - return trainloader, testloader -end - -function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) - #! format: off - return Chain( - Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), - BatchNorm(dim), - [Chain( - SkipConnection( - Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, - pad=SamePad()), BatchNorm(dim)), +), - Conv((1, 1), dim => dim, gelu), BatchNorm(dim)) - for _ in 1:depth]..., - GlobalMeanPool(), - FlattenLayer(), - Dense(dim => 10) - ) - #! 
format: on -end - -function accuracy(model, ps, st, dataloader) - total_correct, total = 0, 0 - st = Lux.testmode(st) - for (x, y) in dataloader - target_class = onecold(y) - predicted_class = onecold(first(model(x, ps, st))) - total_correct += sum(target_class .== predicted_class) - total += length(target_class) - end - return total_correct / total -end - -Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, - patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, - clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01) - rng = StableRNG(seed) - - gdev = gpu_device() - trainloader, testloader = get_dataloaders(batchsize) .|> gdev - - model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) - ps, st = Lux.setup(rng, model) |> gdev - - opt = AdamW(; eta=lr_max, lambda=weight_decay) - clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) - - train_state = Training.TrainState( - model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay)) - - lr_schedule = linear_interpolation( - [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) - - loss = CrossEntropyLoss(; logits=Val(true)) - - for epoch in 1:epochs - stime = time() - lr = 0 - for (i, (x, y)) in enumerate(trainloader) - lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) - train_state = Optimisers.adjust!(train_state, lr) - (_, _, _, train_state) = Training.single_train_step!( - AutoZygote(), loss, (x, y), train_state) - end - ttime = time() - stime - - train_acc = accuracy( - model, train_state.parameters, train_state.states, trainloader) * 100 - test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) * - 100 - - @printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \ - Time: %.2f\n" epoch lr train_acc test_acc ttime - end -end diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl index fe0116bb4e..2e3ee5c7e7 100644 --- a/src/helpers/optimizers.jl +++ b/src/helpers/optimizers.jl @@ -20,17 +20,21 @@ make_reactant_compatible(opt::ReactantCompatibleOptimisersRule) = opt function setfield_if_present(opt, field::Symbol, nt::NamedTuple) if hasfield(typeof(nt), field) - opt = Setfield.set( + return Setfield.set( opt, Setfield.PropertyLens{field}(), - convert( - typeof(getproperty(opt, field)), - Utils.to_rarray(getproperty(nt, field); track_numbers=true) - ) + convert(typeof(getproperty(opt, field)), getproperty(nt, field)) ) end return opt end +function Optimisers._adjust(opt::ReactantCompatibleOptimisersRule, nt::NamedTuple) + for field in fieldnames(typeof(opt)) + opt = setfield_if_present(opt, field, nt) + end + return opt +end + # OptimiserChain function make_reactant_compatible(opt::Optimisers.OptimiserChain) return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts)) @@ -52,10 +56,6 @@ function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) return state, @. 
dx * η end -function Optimisers._adjust(opt::ReactantDescent, nt::NamedTuple) - return setfield_if_present(opt, :eta, nt) -end - # Momentum @concrete struct ReactantMomentum <: ReactantCompatibleOptimisersRule eta @@ -79,12 +79,6 @@ function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) return mvel, mvel end -function Optimisers._adjust(opt::ReactantMomentum, nt::NamedTuple) - opt = setfield_if_present(opt, :eta, nt) - opt = setfield_if_present(opt, :rho, nt) - return opt -end - # Adam @concrete struct ReactantAdam <: ReactantCompatibleOptimisersRule eta @@ -112,20 +106,13 @@ function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps mt, vt, βt = state - @. mt = β[1] * mt + (1 - β[1]) * dx - @. vt = β[2] * vt + (1 - β[2]) * abs2(dx) + mt = @. β[1] * mt + (1 - β[1]) * dx + vt = @. β[2] * vt + (1 - β[2]) * abs2(dx) dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η return (mt, vt, βt .* β), dx′ end -function Optimisers._adjust(opt::ReactantAdam, nt::NamedTuple) - opt = setfield_if_present(opt, :eta, nt) - opt = setfield_if_present(opt, :beta, nt) - opt = setfield_if_present(opt, :epsilon, nt) - return opt -end - # AdamW @concrete struct ReactantAdamW <: ReactantCompatibleOptimisersRule eta @@ -158,9 +145,9 @@ function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) whe mt, vt, βt = state # standard Adam update with learning rate eta=1 - @. mt = β[1] * mt + (1 - β[1]) * dx - @. vt = β[2] * vt + (1 - β[2]) * abs2(dx) - dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η + mt = @. β[1] * mt + (1 - β[1]) * dx + vt = @. β[2] * vt + (1 - β[2]) * abs2(dx) + dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) # apply learning rate and weight decay if o.couple @@ -172,13 +159,4 @@ function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) whe return (mt, vt, βt .* β), dx′′ end -function Optimisers._adjust(opt::ReactantAdamW, nt::NamedTuple) - opt = setfield_if_present(opt, :eta, nt) - opt = setfield_if_present(opt, :beta, nt) - opt = setfield_if_present(opt, :lambda, nt) - opt = setfield_if_present(opt, :epsilon, nt) - opt = setfield_if_present(opt, :couple, nt) - return opt -end - end
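
As a quick illustration of the generic `Optimisers._adjust` added above (a minimal sketch,
not taken from the patch; `opt` here stands for any rule subtyping
`ReactantCompatibleOptimisersRule`), this is what lets the scheduler step in
`examples/CIFAR10/common.jl` go through the standard Optimisers.jl API without per-rule
`_adjust` methods:

```julia
using Optimisers

# `Optimisers.adjust(rule; kwargs...)` forwards (roughly) to
# `Optimisers._adjust(rule, NamedTuple(kwargs))`. With the method added above, every field
# named in the NamedTuple is overwritten via `setfield_if_present`, while all other fields
# keep their current values.
opt = Optimisers.adjust(opt; eta=1.0f-3)

# In the training loop this is driven through the TrainState, as in common.jl:
# train_state = Optimisers.adjust!(train_state, lr)
```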