This repository has been archived by the owner on Nov 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cleanup cleanup fix improve docstring require cuDNN none functional only if cuDNN is functional separate cuDNN extension cleanup
- Loading branch information
1 parent
b914979
commit 30dcabc
Showing
34 changed files
with
719 additions
and
419 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
Manifest.toml | ||
*.cov | ||
generated | ||
build | ||
.vscode | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
module DeviceUtilsAMDGPUExt | ||
|
||
using Adapt: Adapt | ||
using AMDGPU: AMDGPU | ||
using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! | ||
using Random: Random | ||
|
||
__init__() = reset_gpu_device!() | ||
|
||
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) | ||
|
||
function _check_use_amdgpu!() | ||
USE_AMD_GPU[] === nothing || return | ||
|
||
USE_AMD_GPU[] = AMDGPU.functional() | ||
if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) | ||
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ | ||
available." maxlog=1 | ||
end | ||
return | ||
end | ||
|
||
DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true | ||
function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool | ||
_check_use_amdgpu!() | ||
return USE_AMD_GPU[] | ||
end | ||
|
||
function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) | ||
return AMDGPUDevice(nothing) | ||
end | ||
function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) | ||
id > length(AMDGPU.devices()) && | ||
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) | ||
old_dev = AMDGPU.device() | ||
AMDGPU.device!(AMDGPU.devices()[id]) | ||
device = AMDGPUDevice(AMDGPU.device()) | ||
AMDGPU.device!(old_dev) | ||
return device | ||
end | ||
|
||
DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() | ||
|
||
# Query Device from Array | ||
function DeviceUtils.get_device(x::AMDGPU.AnyROCArray) | ||
parent_x = parent(x) | ||
parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) | ||
return DeviceUtils.get_device(parent_x) | ||
end | ||
|
||
# Set Device | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) | ||
return AMDGPU.device!(dev) | ||
end | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) | ||
return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) | ||
end | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) | ||
id = mod1(rank + 1, length(AMDGPU.devices())) | ||
return DeviceUtils.set_device!(AMDGPUDevice, id) | ||
end | ||
|
||
# Device Transfer | ||
## To GPU | ||
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) | ||
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) | ||
old_dev = AMDGPU.device() # remember the current device | ||
dev = DeviceUtils.get_device(x) | ||
if !(dev isa AMDGPUDevice) | ||
AMDGPU.device!(to.device) | ||
x_new = AMDGPU.roc(x) | ||
AMDGPU.device!(old_dev) | ||
return x_new | ||
elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) | ||
return x | ||
else | ||
AMDGPU.device!(to.device) | ||
x_new = copy(x) | ||
AMDGPU.device!(old_dev) | ||
return x_new | ||
end | ||
end | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
module DeviceUtilsCUDAExt | ||
|
||
using Adapt: Adapt | ||
using CUDA: CUDA | ||
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector | ||
using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice, reset_gpu_device! | ||
using Random: Random | ||
|
||
function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) | ||
id > length(CUDA.devices()) && | ||
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) | ||
old_dev = CUDA.device() | ||
CUDA.device!(id - 1) | ||
device = CUDADevice(CUDA.device()) | ||
CUDA.device!(old_dev) | ||
return device | ||
end | ||
|
||
function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) | ||
return CUDADevice(nothing) | ||
end | ||
|
||
DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() | ||
|
||
# Query Device from Array | ||
function DeviceUtils.get_device(x::CUDA.AnyCuArray) | ||
parent_x = parent(x) | ||
parent_x === x && return CUDADevice(CUDA.device(x)) | ||
return DeviceUtils.get_device(parent_x) | ||
end | ||
function DeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) | ||
return CUDADevice(CUDA.device(x.nzVal)) | ||
end | ||
|
||
# Set Device | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) | ||
return CUDA.device!(dev) | ||
end | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) | ||
return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) | ||
end | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) | ||
id = mod1(rank + 1, length(CUDA.devices())) | ||
return DeviceUtils.set_device!(CUDADevice, id) | ||
end | ||
|
||
# Device Transfer | ||
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) | ||
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) | ||
old_dev = CUDA.device() # remember the current device | ||
dev = DeviceUtils.get_device(x) | ||
if !(dev isa CUDADevice) | ||
CUDA.device!(to.device) | ||
x_new = CUDA.cu(x) | ||
CUDA.device!(old_dev) | ||
return x_new | ||
elseif dev.device == to.device | ||
return x | ||
else | ||
CUDA.device!(to.device) | ||
x_new = copy(x) | ||
CUDA.device!(old_dev) | ||
return x_new | ||
end | ||
end | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() | ||
|
||
# Defining as extensions seems to case precompilation errors | ||
@static if isdefined(CUDA.CUSPARSE, :SparseArrays) | ||
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) | ||
return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) | ||
end | ||
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) | ||
return CUDA.CUSPARSE.SparseArrays.SparseVector(x) | ||
end | ||
else | ||
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ | ||
an issue in DeviceUtils.jl repository." | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
module DeviceUtilsFillArraysExt | ||
|
||
using Adapt: Adapt | ||
using FillArrays: FillArrays, AbstractFill | ||
using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice | ||
|
||
Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x | ||
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
module DeviceUtilsGPUArraysExt | ||
|
||
using Adapt: Adapt | ||
using GPUArrays: GPUArrays | ||
using DeviceUtils: CPUDevice | ||
using Random: Random | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module DeviceUtilsMetalExt | ||
|
||
using Adapt: Adapt | ||
using GPUArrays: GPUArrays | ||
using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! | ||
using Metal: Metal, MtlArray | ||
|
||
__init__() = reset_gpu_device!() | ||
|
||
DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true | ||
function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) | ||
return Metal.functional() | ||
end | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) | ||
|
||
# Query Device from Array | ||
DeviceUtils.get_device(::MtlArray) = MetalDevice() | ||
|
||
# Device Transfer | ||
## To GPU | ||
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
module DeviceUtilsRecursiveArrayToolsExt | ||
|
||
using Adapt: Adapt, adapt | ||
using DeviceUtils: DeviceUtils, AbstractDevice | ||
using RecursiveArrayTools: VectorOfArray, DiffEqArray | ||
|
||
# We want to preserve the structure | ||
function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) | ||
return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) | ||
end | ||
|
||
function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) | ||
# Don't move the `time` to the GPU | ||
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) | ||
end | ||
|
||
function DeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) | ||
return mapreduce(DeviceUtils.get_device, DeviceUtils.__combine_devices, x.u) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
module DeviceUtilsReverseDiffExt | ||
|
||
using DeviceUtils: DeviceUtils | ||
using ReverseDiff: ReverseDiff | ||
|
||
@inline function DeviceUtils.get_device(x::ReverseDiff.TrackedArray) | ||
return DeviceUtils.get_device(ReverseDiff.value(x)) | ||
end | ||
@inline function DeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) | ||
return DeviceUtils.get_device(ReverseDiff.value.(x)) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module DeviceUtilsSparseArraysExt | ||
|
||
using Adapt: Adapt | ||
using DeviceUtils: CPUDevice | ||
using SparseArrays: AbstractSparseArray | ||
|
||
Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
module DeviceUtilsTrackerExt | ||
|
||
using Adapt: Adapt | ||
using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, | ||
oneAPIDevice | ||
using Tracker: Tracker | ||
|
||
@inline function DeviceUtils.get_device(x::Tracker.TrackedArray) | ||
return DeviceUtils.get_device(Tracker.data(x)) | ||
end | ||
@inline function DeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) | ||
return DeviceUtils.get_device(Tracker.data.(x)) | ||
end | ||
|
||
@inline DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true | ||
|
||
for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, | ||
CUDADevice{Nothing}, MetalDevice, oneAPIDevice) | ||
@eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) | ||
@warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ | ||
to Tracker.TrackedArray." maxlog=1 | ||
return to(Tracker.collect(x)) | ||
end | ||
end | ||
|
||
end |
Oops, something went wrong.