Skip to content

Commit

Permalink
SoaSytem draft, some type instabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 22, 2024
1 parent a34d094 commit 4a515a6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 13 deletions.
4 changes: 4 additions & 0 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ Base.length(::XState) = 1

# ----------- some basic manipulations

_syms(nt::NamedTuple{SYMS, TT}) where {SYMS, TT} = SYMS
_syms(nt::Type{NamedTuple{SYMS, TT}}) where {SYMS, TT} = SYMS


# extract the symbols and the types
_syms(X::XState) = _syms(typeof(X))
_syms(::Type{<: XState{NamedTuple{SYMS, TT}}}) where {SYMS, TT} = SYMS
Expand Down
68 changes: 60 additions & 8 deletions src/structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import AtomsBase
import AtomsBase: AbstractSystem, ChemicalElement,
position, velocity, atomic_mass, atomic_number,
atomic_symbol
atomic_symbol, n_dimensions, bounding_box,
boundary_conditions, periodicity, get_cell

# ---------------------------------------------------
# an `Atom` is now just a `PState`, so we define
Expand All @@ -13,6 +14,11 @@ symbol(::typeof(velocity)) = :𝐯
symbol(::typeof(atomic_mass)) = :m
symbol(::typeof(atomic_symbol)) = :Z

const _atom_syms = (𝐫 = position,
𝐯 = velocity,
m = atomic_mass,
Z = atomic_symbol)

position(atom::PState) = atom.𝐫
velocity(atom::PState) = atom.𝐯
atomic_mass(atom::PState) = atom.m
Expand All @@ -27,7 +33,8 @@ atom(at; properties = (position, atomic_mass, atomic_symbol)) =
PState((; [symbol(p) => p(at) for p in properties]...))

# ---------------------------------------------------

# Array of Structs System
#
mutable struct AosSystem{D, TCELL, TPART} <: AbstractSystem{D}
cell::TCELL
particles::Vector{TPART}
Expand All @@ -46,12 +53,12 @@ function AosSystem(sys::AbstractSystem;
end


# ---------------------------------------------------
# implementing the interface
# implementing the AtomsBase interface

Base.length(at::AosSystem) = length(at.particles)

Base.getindex(at::AosSystem, i::Int) = at.particles[i]
Base.getindex(at::AosSystem, inds::AbstractVector) = at.particles[inds]

for f in (:position, :velocity, :atomic_mass, :atomic_symbol)
@eval $f(sys::AosSystem) = [ $f(x) for x in sys.particles ]
Expand All @@ -61,8 +68,53 @@ end

AtomsBase.get_cell(at::AosSystem) = at.cell

AtomsBase.n_dimensions(at::AosSystem) = AtomsBase.n_dimensions(at.cell)
AtomsBase.bounding_box(at::AosSystem) = AtomsBase.bounding_box(at.cell)
AtomsBase.boundary_conditions(at::AosSystem) = AtomsBase.boundary_conditions(at.cell)
AtomsBase.periodicity(at::AosSystem) = AtomsBase.periodicity(at.cell)
for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity)
@eval $f(at::AosSystem) = $f(at.cell)
end


# ---------------------------------------------------
# Struct of Arrays System

mutable struct SoaSystem{D, TCELL, NT} <: AbstractSystem{D}
cell::TCELL
arrays::NT
# --------
meta::Dict{String, Any}
end

function SoaSystem(sys::AbstractSystem;
properties = (position, atomic_mass, atomic_symbol), )

arrays = (; [symbol(p) => p(sys) for p in properties]... )
cell = AtomsBase.get_cell(sys)
D = AtomsBase.n_dimensions(cell)
return SoaSystem{D, typeof(cell), typeof(arrays)}(
cell, arrays, Dict{String, Any}())
end


# implementing the AtomsBase interface

Base.length(at::SoaSystem) = length(at.particles)

function Base.getindex(sys::SoaSystem{D, TCELL, NT}, i::Integer) where {D, TCELL, NT}
SYMS = _syms(NT)
return PState(; ntuple(a -> SYMS[a] => sys.arrays[SYMS[a]][i], length(SYMS))...)
end

Base.getindex(sys::SoaSystem, inds::AbstractVector{<: Integer}) =
[ sys[i] for i in inds ]

for f in (:position, :velocity, :atomic_mass, :atomic_symbol)
@eval $f(sys::SoaSystem) = getfield(sys.arrays, symbol($f))
@eval $f(sys::SoaSystem, i::Integer) = getfield(sys.arrays, symbol($f))[i]
@eval $f(sys::SoaSystem, inds::AbstractVector) = getfield(sys.arrays, symbol($f))[inds]
end

AtomsBase.get_cell(at::SoaSystem) = at.cell

for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity)
@eval $f(at::SoaSystem) = $f(at.cell)
end

17 changes: 12 additions & 5 deletions test/test_atomsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ display(x)

sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1);
aos = DP.AosSystem(sys);
soa = DP.SoaSystem(sys);

aos[1]
aos[1, position]
aos[1, atomic_mass]
atomic_mass(aos, 1)
for i = 1:10
@test aos[i] == soa[i]

for f in (position, atomic_mass, atomic_symbol)
@test f(aos, i) == f(soa, i)
end
end

for f in (get_cell, periodicity, boundary_conditions, bounding_box, n_dimensions)
@test f(aos) == f(soa)
end

get_cell(aos)

0 comments on commit 4a515a6

Please sign in to comment.