diff --git a/src/states.jl b/src/states.jl index 42fcc9e..71a3908 100644 --- a/src/states.jl +++ b/src/states.jl @@ -68,6 +68,10 @@ Base.length(::XState) = 1 # ----------- some basic manipulations +_syms(nt::NamedTuple{SYMS, TT}) where {SYMS, TT} = SYMS +_syms(nt::Type{NamedTuple{SYMS, TT}}) where {SYMS, TT} = SYMS + + # extract the symbols and the types _syms(X::XState) = _syms(typeof(X)) _syms(::Type{<: XState{NamedTuple{SYMS, TT}}}) where {SYMS, TT} = SYMS diff --git a/src/structures.jl b/src/structures.jl index 6e320d9..14ed00b 100644 --- a/src/structures.jl +++ b/src/structures.jl @@ -2,7 +2,8 @@ import AtomsBase import AtomsBase: AbstractSystem, ChemicalElement, position, velocity, atomic_mass, atomic_number, - atomic_symbol + atomic_symbol, n_dimensions, bounding_box, + boundary_conditions, periodicity, get_cell # --------------------------------------------------- # an `Atom` is now just a `PState`, so we define @@ -13,6 +14,11 @@ symbol(::typeof(velocity)) = :𝐯 symbol(::typeof(atomic_mass)) = :m symbol(::typeof(atomic_symbol)) = :Z +const _atom_syms = (𝐫 = position, + 𝐯 = velocity, + m = atomic_mass, + Z = atomic_symbol) + position(atom::PState) = atom.𝐫 velocity(atom::PState) = atom.𝐯 atomic_mass(atom::PState) = atom.m @@ -27,7 +33,8 @@ atom(at; properties = (position, atomic_mass, atomic_symbol)) = PState((; [symbol(p) => p(at) for p in properties]...)) # --------------------------------------------------- - +# Array of Structs System +# mutable struct AosSystem{D, TCELL, TPART} <: AbstractSystem{D} cell::TCELL particles::Vector{TPART} @@ -46,12 +53,12 @@ function AosSystem(sys::AbstractSystem; end -# --------------------------------------------------- -# implementing the interface +# implementing the AtomsBase interface Base.length(at::AosSystem) = length(at.particles) Base.getindex(at::AosSystem, i::Int) = at.particles[i] +Base.getindex(at::AosSystem, inds::AbstractVector) = at.particles[inds] for f in (:position, :velocity, :atomic_mass, :atomic_symbol) @eval $f(sys::AosSystem) = [ $f(x) for x in sys.particles ] @@ -61,8 +68,53 @@ end AtomsBase.get_cell(at::AosSystem) = at.cell -AtomsBase.n_dimensions(at::AosSystem) = AtomsBase.n_dimensions(at.cell) -AtomsBase.bounding_box(at::AosSystem) = AtomsBase.bounding_box(at.cell) -AtomsBase.boundary_conditions(at::AosSystem) = AtomsBase.boundary_conditions(at.cell) -AtomsBase.periodicity(at::AosSystem) = AtomsBase.periodicity(at.cell) +for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity) + @eval $f(at::AosSystem) = $f(at.cell) +end + + +# --------------------------------------------------- +# Struct of Arrays System + +mutable struct SoaSystem{D, TCELL, NT} <: AbstractSystem{D} + cell::TCELL + arrays::NT + # -------- + meta::Dict{String, Any} +end + +function SoaSystem(sys::AbstractSystem; + properties = (position, atomic_mass, atomic_symbol), ) + + arrays = (; [symbol(p) => p(sys) for p in properties]... ) + cell = AtomsBase.get_cell(sys) + D = AtomsBase.n_dimensions(cell) + return SoaSystem{D, typeof(cell), typeof(arrays)}( + cell, arrays, Dict{String, Any}()) +end + + +# implementing the AtomsBase interface + +Base.length(at::SoaSystem) = length(at.particles) + +function Base.getindex(sys::SoaSystem{D, TCELL, NT}, i::Integer) where {D, TCELL, NT} + SYMS = _syms(NT) + return PState(; ntuple(a -> SYMS[a] => sys.arrays[SYMS[a]][i], length(SYMS))...) +end + +Base.getindex(sys::SoaSystem, inds::AbstractVector{<: Integer}) = + [ sys[i] for i in inds ] + +for f in (:position, :velocity, :atomic_mass, :atomic_symbol) + @eval $f(sys::SoaSystem) = getfield(sys.arrays, symbol($f)) + @eval $f(sys::SoaSystem, i::Integer) = getfield(sys.arrays, symbol($f))[i] + @eval $f(sys::SoaSystem, inds::AbstractVector) = getfield(sys.arrays, symbol($f))[inds] +end + +AtomsBase.get_cell(at::SoaSystem) = at.cell + +for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity) + @eval $f(at::SoaSystem) = $f(at.cell) +end diff --git a/test/test_atomsbase.jl b/test/test_atomsbase.jl index 9db321e..6330910 100644 --- a/test/test_atomsbase.jl +++ b/test/test_atomsbase.jl @@ -32,10 +32,17 @@ display(x) sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1); aos = DP.AosSystem(sys); +soa = DP.SoaSystem(sys); -aos[1] -aos[1, position] -aos[1, atomic_mass] -atomic_mass(aos, 1) +for i = 1:10 + @test aos[i] == soa[i] + + for f in (position, atomic_mass, atomic_symbol) + @test f(aos, i) == f(soa, i) + end +end + +for f in (get_cell, periodicity, boundary_conditions, bounding_box, n_dimensions) + @test f(aos) == f(soa) +end -get_cell(aos) \ No newline at end of file