feat: update ConvMixer to support reactant (#1063)
* fix: update to new reactant changes
* fix: use enzyme correctly
* fix: update training code
* feat: handle optimisers correctly
* fix: don't transfer via CPU
* refactor: remove promote_to handling
* revert: specific version bumps
* docs: add multiple CIFAR10 examples using Reactant
* feat: BF16 training + inference
* fix: incorrect AdamW handling
* feat: implement resnet20 baseline
* fix: conv mixer working again
Showing 13 changed files with 405 additions and 240 deletions.
# Train Vision Models on CIFAR-10

✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚

We have the following scripts to train vision models on CIFAR-10:

1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers.
2. `mlp_mixer.jl`: MLP-Mixer model.
3. `conv_mixer.jl`: ConvMixer model.

To get the options for each script, run the script with the `--help` flag.

> [!NOTE]
> To train a model using Reactant.jl, pass `--backend=reactant` to the script. This is
> the recommended way to train the models in this directory.

> [!NOTE]
> Passing `--bfloat16` uses BFloat16 precision for training. This requires Julia 1.11 or
> above.

## Simple CNN

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    simple_cnn.jl \
    --backend=reactant
```

On an RTX 4050 6GB Laptop GPU, training takes approximately 3 minutes, and the final
training and test accuracies are 97% and 65%, respectively.

## ResNet 20

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    resnet20.jl \
    --backend=reactant
```

On an RTX 3060 GPU, each epoch takes about 4.5 seconds, and the final training and test
accuracies are 89% and 75%, respectively.

## ConvMixer

> [!NOTE]
> This code has been adapted from https://github.com/locuslab/convmixer-cifar10

This is a simple ConvMixer training script for CIFAR-10. It's probably a good starting
point for new experiments on small datasets.

You can reach around **90.0%** accuracy in just **25 epochs** by running the script with
the following arguments, which train a ConvMixer-256/8 with kernel size 5 and patch size 2.

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    conv_mixer.jl \
    --backend=reactant
```

### Notes

1. To match the results from the original repo, we need more augmentation strategies that
   are not currently implemented in DataAugmentation.jl.
**`common.jl`** (shared by the training scripts via `include("common.jl")`):
```julia
using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
      Printf, ProgressTables, Random, BFloat16s
using Reactant, LuxCUDA

# Wraps a dataset so that indexing applies `transform` on the fly.
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

# Batched indexing: decode the raw data to images, augment each one, and stack the
# results into a single WHCN array alongside one-hot encoded labels.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    y = onehotbatch(ds.dataset.targets[idxs], 0:9)
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
    # Channel-wise CIFAR-10 statistics used for normalization
    cifar10_mean = (0.4914, 0.4822, 0.4465) .|> T
    cifar10_std = (0.2471, 0.2435, 0.2616) .|> T

    # Random crops and horizontal flips for training; the test set is only normalized
    train_transform = RandomResizeCrop((32, 32)) |>
                      Maybe(FlipX{2}()) |>
                      ImageToTensor() |>
                      Normalize(cifar10_mean, cifar10_std) |>
                      ToEltype(T)

    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)

    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

    return trainloader, testloader
end
```
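A hedged usage sketch (hypothetical, not part of the committed file): with the definitions above loaded and the CIFAR-10 data available through MLDatasets, the loaders yield WHCN image batches with one-hot labels.

```julia
trainloader, _ = get_cifar10_dataloaders(Float32, 128)
x, y = first(trainloader)
size(x)  # (32, 32, 3, 128): width x height x channels x batch
size(y)  # (10, 128): one-hot labels for the 10 classes
```

The rest of `common.jl` provides evaluation, device selection, and the shared training loop: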
```julia
# Classification accuracy over a dataloader. Predictions and targets are moved to
# the CPU before comparison.
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    cdev = cpu_device()
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

# Map the scripts' `--backend` option to a Lux device.
function get_accelerator_device(backend::String)
    if backend == "gpu_if_available"
        return gpu_device()
    elseif backend == "gpu"
        return gpu_device(; force=true)
    elseif backend == "reactant"
        return reactant_device(; force=true)
    elseif backend == "cpu"
        return cpu_device()
    else
        error("Invalid backend: $(backend). Valid options are: `gpu_if_available`, `gpu`, \
              `reactant`, and `cpu`.")
    end
end

function train_model(
        model, opt, scheduler=nothing;
        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
        bfloat16::Bool=false
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    # Select the training precision (BFloat16 requires Julia 1.11 or above)
    prec = bfloat16 ? bf16 : f32
    prec_jl = bfloat16 ? BFloat16 : Float32
    prec_str = bfloat16 ? "BFloat16" : "Float32"
    @printf "[Info] Using %s precision\n" prec_str

    accelerator_device = get_accelerator_device(backend)
    # Reactant compiles for static shapes, so drop the final partial batch
    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
    trainloader, testloader = get_cifar10_dataloaders(prec_jl, batchsize; kwargs...) |>
                              accelerator_device

    ps, st = Lux.setup(rng, model) |> prec |> accelerator_device

    train_state = Training.TrainState(model, ps, st, opt)

    # Enzyme for Reactant, Zygote otherwise
    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

    if backend == "reactant"
        # Compile the inference forward pass once; it is reused for accuracy evaluation
        x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> accelerator_device
        @printf "[Info] Compiling model with Reactant.jl\n"
        st_test = Lux.testmode(st)
        model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
        @printf "[Info] Model compiled!\n"
    else
        model_compiled = model
    end

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    pt = ProgressTable(;
        header=[
            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
        ],
        widths=[24, 24, 24, 24, 24],
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=[:center, :center, :center, :center, :center]
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        stime = time()
        lr = 0
        for (i, (x, y)) in enumerate(trainloader)
            if scheduler !== nothing
                # Evaluate the schedule at a fractional epoch and update the optimiser
                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
                train_state = Optimisers.adjust!(train_state, lr)
            end
            (_, loss, _, train_state) = Training.single_train_step!(
                adtype, loss_fn, (x, y), train_state
            )
            isnan(loss) && error("NaN loss encountered!")
        end
        ttime = time() - stime

        train_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), trainloader
        ) * 100
        test_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), testloader
        ) * 100

        scheduler === nothing && (lr = NaN32)
        next(pt, [epoch, lr, train_acc, test_acc, ttime])
    end

    finalize(pt)
    @printf "[Info] Finished training\n"
end
```
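Since `train_model` accepts any Lux model and Optimisers rule, it can also be driven directly. A minimal hypothetical sketch; the toy model and settings here are illustrative, not from the repository:

```julia
using Lux, Optimisers
include("common.jl")

model = Chain(FlattenLayer(), Dense(32 * 32 * 3 => 10))  # toy linear classifier
train_model(model, Adam(0.001f0); backend="cpu", batchsize=256, epochs=2)
```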
**`conv_mixer.jl`** (the ConvMixer training script):
```julia
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote

include("common.jl")

function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
    #! format: off
    return Chain(
        # Patch embedding: a strided convolution over non-overlapping patches
        Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
        BatchNorm(dim),
        [
            Chain(
                # Depthwise convolution (groups=dim) with a residual connection...
                SkipConnection(
                    Chain(
                        Conv(
                            (kernel_size, kernel_size), dim => dim, gelu;
                            groups=dim, pad=SamePad()
                        ),
                        BatchNorm(dim)
                    ),
                    +
                ),
                # ...followed by a pointwise (1x1) convolution
                Conv((1, 1), dim => dim, gelu),
                BatchNorm(dim)
            )
            for _ in 1:depth
        ]...,
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(dim => 10)
    )
    #! format: on
end
```
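As a quick shape check, here is a hypothetical snippet (not part of the committed script) that runs a CIFAR-10-sized batch through the model:

```julia
model = ConvMixer(; dim=256, depth=8)            # the ConvMixer-256/8 from the README
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 32, 32, 3, 4)                  # a batch of 4 CIFAR-10 images
logits, _ = model(x, ps, st)
size(logits)                                     # (10, 4): one logit per class per image
```

The Comonicon entry point below wires the model into the shared `train_model` loop: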
```julia
Comonicon.@main function main(;
        batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001,
        clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
        backend::String="reactant", bfloat16::Bool=false
)
    model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

    opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

    # Piecewise-linear schedule: warm up to lr_max over the first 2/5 of training,
    # decay to lr_max / 20 by 4/5, then anneal towards zero.
    lr_schedule = linear_interpolation(
        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
    )

    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
end
```
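To make the schedule concrete: with the defaults `epochs=25` and `lr_max=0.05`, the breakpoints fall at epochs 0, 10, 20, and 26. A small illustrative sketch of the values it produces:

```julia
using Interpolations

epochs, lr_max = 25, 0.05
sched = linear_interpolation(
    [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)

sched(0.0)   # 0.0: training starts from a zero learning rate
sched(10.0)  # 0.05: peak learning rate at epoch 10
sched(20.0)  # 0.0025: decayed to lr_max / 20 by epoch 20
sched(25.0)  # ~0.00042: annealing towards zero at the final epoch
```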
Comments on commit 2bf2ca8:

@JuliaRegistrator register

Registration pull request created: JuliaRegistries/General/122266