Skip to content

Commit

Permalink
make soa-getindex generated
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 22, 2024
1 parent 4a515a6 commit 426faf1
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 6 deletions.
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,41 @@ g = Zygote.gradient(f, [x1, x2])[1]

g[1].𝐫 2 * x1.𝐫
# true

# ---------------------------------------------------
# Prototype AtomsBase system implementations
# Both AosSystem and SoaSystem are fully flexible regarding the
# properties of the particles.

using AtomsBuilder
sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1); # AtomsBase.FlexibleSystem
aos = DP.AosSystem(sys);
soa = DP.SoaSystem(sys);

x1 = aos[1] # PState, just sys.particles[1]
x2 = soa[1] # PState, generated from the arrays in sys
isbits(x1) # true
isbits(x2) # true

# accessors are non-allocating:
_check_allocs(sys) = ( @allocated position(sys, 1) +
@allocated atomic_mass(sys, 1) +
@allocated sys[1] )
_check_allocs(sys) # 288
_check_allocs(aos) # 0
_check_allocs(soa) # 0

# this has performance implications
using BenchmarkTools

# Silly test 1 : sum up the positions via `position(sys, i)` accessor
silly_test_1(sys) = sum( position(sys, i) for i = 1:length(sys) )
@btime silly_test_1($sys) # 8.819 μs (320 allocations: 12.00 KiB)
@btime silly_test_1($aos) # 50.447 ns (0 allocations: 0 bytes)
@btime silly_test_1($soa) # 50.405 ns (0 allocations: 0 bytes)

silly_test_2(sys) = sum( position(x) for x in sys )
@btime silly_test_2($sys) # 10.750 μs (256 allocations: 18.00 KiB)
@btime silly_test_2($aos) # 47.950 ns (0 allocations: 0 bytes)
@btime silly_test_2($soa) # 48.794 ns (0 allocations: 0 bytes)
```
31 changes: 26 additions & 5 deletions src/structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,26 @@ end

# implementing the AtomsBase interface

Base.length(at::SoaSystem) = length(at.particles)

function Base.getindex(sys::SoaSystem{D, TCELL, NT}, i::Integer) where {D, TCELL, NT}
SYMS = _syms(NT)
return PState(; ntuple(a -> SYMS[a] => sys.arrays[SYMS[a]][i], length(SYMS))...)
Base.length(at::SoaSystem) = length(at.arrays[1])

# this implementation seems canonical but appears to be type unstable
# unclear to me exactly why, maybe it can be fixed.
# function Base.getindex(sys::SoaSystem{D, TCELL, NT}, i::Integer) where {D, TCELL, NT}
# SYMS = _syms(NT)
# return PState(; ntuple(a -> SYMS[a] => sys.arrays[SYMS[a]][i], length(SYMS))...)
# end

@generated function Base.getindex(sys::SoaSystem{D, TCELL, NT}, i::Integer) where {D, TCELL, NT}
SYMS = _syms(NT)
# very naive code-writing ... probably there is a nicer way ...
code = "PState("
for sym in SYMS
code *= "$(sym) = sys.arrays.$sym[i], "
end
code *= ")"
return quote
$(Meta.parse(code))
end
end

Base.getindex(sys::SoaSystem, inds::AbstractVector{<: Integer}) =
Expand All @@ -118,3 +133,9 @@ for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity)
@eval $f(at::SoaSystem) = $f(at.cell)
end




# ---------------------------------------------------------------
# Extension of the AtomsBase interface with setter functions

37 changes: 37 additions & 0 deletions test/_readme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,40 @@ g = Zygote.gradient(f, [x1, x2])[1]

g[1].𝐫 2 * x1.𝐫
# true

# ---------------------------------------------------
# Prototype AtomsBase system implementations
# Both AosSystem and SoaSystem are fully flexible regarding the
# properties of the particles.

using AtomsBuilder
sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1); # AtomsBase.FlexibleSystem
aos = DP.AosSystem(sys);
soa = DP.SoaSystem(sys);

x1 = aos[1] # PState, just sys.particles[1]
x2 = soa[1] # PState, generated from the arrays in sys
isbits(x1) # true
isbits(x2) # true

# accessors are non-allocating:
_check_allocs(sys) = ( @allocated position(sys, 1) +
@allocated atomic_mass(sys, 1) +
@allocated sys[1] )
_check_allocs(sys) # 288
_check_allocs(aos) # 0
_check_allocs(soa) # 0

# this has performance implications
using BenchmarkTools

# Silly test 1 : sum up the positions via `position(sys, i)` accessor
silly_test_1(sys) = sum( position(sys, i) for i = 1:length(sys) )
@btime silly_test_1($sys) # 8.819 μs (320 allocations: 12.00 KiB)
@btime silly_test_1($aos) # 50.447 ns (0 allocations: 0 bytes)
@btime silly_test_1($soa) # 50.405 ns (0 allocations: 0 bytes)

silly_test_2(sys) = sum( position(x) for x in sys )
@btime silly_test_2($sys) # 10.750 μs (256 allocations: 18.00 KiB)
@btime silly_test_2($aos) # 47.950 ns (0 allocations: 0 bytes)
@btime silly_test_2($soa) # 48.794 ns (0 allocations: 0 bytes)
30 changes: 29 additions & 1 deletion test/test_atomsbase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

using DecoratedParticles, AtomsBase, StaticArrays, Unitful, Test
using DecoratedParticles, AtomsBase, StaticArrays, Unitful, Test,
BenchmarkTools
using AtomsBase: ChemicalElement, Atom
using AtomsBuilder: bulk, rattle!
DP = DecoratedParticles
Expand Down Expand Up @@ -46,3 +47,30 @@ for f in (get_cell, periodicity, boundary_conditions, bounding_box, n_dimensions
@test f(aos) == f(soa)
end


##
# some performance related tests

sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1);
aos = DP.AosSystem(sys);
soa = DP.SoaSystem(sys);

x1 = aos[1]
x2 = soa[1]
@test isbits(x1)
@test isbits(x2)

@info("Checking allocations during accessors, not sure why it shows anything?")
function _check_allocs(sys)
a1 = @allocated position(sys, 1)
a2 = @allocated atomic_mass(sys, 1)
a3 = @allocated sys[1]
return a1 + a2 + a3
end

@test _check_allocs(aos) == 0
@test _check_allocs(soa) == 0


##

0 comments on commit 426faf1

Please sign in to comment.