Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ArrayPartition from RecursiveArrayTools.jl #118

Merged
merged 24 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6f07dcf
use VectorOfArray from RecursiveArrayTools.jl as datastructure for ti…
JoshuaLampert Jun 17, 2024
f8fa9d5
remove temporary vectors
JoshuaLampert Jun 17, 2024
e4ab265
some fixes
JoshuaLampert Jun 17, 2024
0605759
use view instead of creating vector
JoshuaLampert Jun 17, 2024
fd8d2dc
format
JoshuaLampert Jun 17, 2024
4e34bc4
bump lower compat of RecursiveArrayTools
JoshuaLampert Jun 17, 2024
633738d
set lower compat of RecursiveArrayTools to 3.3
JoshuaLampert Jun 17, 2024
c9dbd50
remove compat for SciMLBase v1.93
JoshuaLampert Jun 17, 2024
e0b3988
bump lower compat of SciMLBase to 2.11
JoshuaLampert Jun 17, 2024
be75464
bump lower compat of DiffEqBase to 6.130
JoshuaLampert Jun 17, 2024
81e828d
bump lower compat of SummationByPartsOperators to 0.5.50
JoshuaLampert Jun 17, 2024
d030e23
bump lower compat of DiffEqBase to 6.143
JoshuaLampert Jun 17, 2024
22d492b
bump lower compat of SummationByPartsOperators to 0.5.52
JoshuaLampert Jun 17, 2024
c3c0348
remove compat 0.17 of BandedMatrices
JoshuaLampert Jun 17, 2024
f6617db
bump lower compat in tests of SummationByPartsOperators to 0.5.52
JoshuaLampert Jun 17, 2024
fc04fee
bump more compats
JoshuaLampert Jun 17, 2024
6044989
use ArrayPartition instead of VectorOfArrays
JoshuaLampert Jun 18, 2024
cc4e3e5
introduce eachvariable(semi)
JoshuaLampert Jun 18, 2024
859f45c
format
JoshuaLampert Jun 18, 2024
ba147c6
fix plotting with conversion
JoshuaLampert Jun 18, 2024
cb0c62f
bump lower compat of RecipesBase to 1.2
JoshuaLampert Jun 18, 2024
5f672fe
bump lower compat in test of OrdinaryDiffEq to 6.62
JoshuaLampert Jun 18, 2024
e94e833
add allocation tests
JoshuaLampert Jun 19, 2024
0f5182c
format
JoshuaLampert Jun 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PolynomialBases = "c74db56a-226d-5e98-8bb0-a6049094aeea"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -22,20 +23,21 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TrixiBase = "9a0f1c46-06d5-4909-a5a3-ce25d3fa3284"

[compat]
BandedMatrices = "0.17, 1"
DiffEqBase = "6.128"
BandedMatrices = "1"
DiffEqBase = "6.143"
Interpolations = "0.14.2, 0.15"
LinearAlgebra = "1"
PolynomialBases = "0.4.15"
Printf = "1"
RecipesBase = "1.1"
RecipesBase = "1.2"
RecursiveArrayTools = "3.3"
Reexport = "1.0"
Roots = "2.0.17"
SciMLBase = "1.93, 2"
SciMLBase = "2.11"
SimpleUnPack = "1.1"
SparseArrays = "1"
StaticArrays = "1"
SummationByPartsOperators = "0.5.41"
SummationByPartsOperators = "0.5.52"
TimerOutputs = "0.5.7"
TrixiBase = "0.1.3"
julia = "1.9"
1 change: 1 addition & 0 deletions src/DispersiveShallowWater.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra: mul!, ldiv!, I, Diagonal, Symmetric, diag, lu, cholesky, ch
using PolynomialBases: PolynomialBases
using Printf: @printf, @sprintf
using RecipesBase: RecipesBase, @recipe, @series
using RecursiveArrayTools: ArrayPartition
using Reexport: @reexport
using Roots: AlefeldPotraShi, find_zero

Expand Down
32 changes: 15 additions & 17 deletions src/callbacks_step/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ function integrals(cb::DiscreteCallback{Condition,
return (; zip(names, integrals)...)
end

function initialize!(cb::DiscreteCallback{Condition, Affect!}, u_ode, t,
function initialize!(cb::DiscreteCallback{Condition, Affect!}, q, t,
integrator) where {Condition, Affect! <: AnalysisCallback}
semi = integrator.p
initial_state_integrals = integrate(u_ode, semi)
initial_state_integrals = integrate(q, semi)

analysis_callback = cb.affect!
analysis_callback.initial_state_integrals = initial_state_integrals
Expand Down Expand Up @@ -184,8 +184,8 @@ end

# This method is just called internally from `(analysis_callback::AnalysisCallback)(integrator)`
# and serves as a function barrier. Additionally, it makes the code easier to profile and optimize.
function (analysis_callback::AnalysisCallback)(io, u_ode, integrator, semi)
_, equations, solver = mesh_equations_solver(semi)
function (analysis_callback::AnalysisCallback)(io, q, integrator, semi)
equations = semi.equations
@unpack analysis_errors, analysis_integrals, tstops, errors, integrals = analysis_callback
@unpack t, dt = integrator
push!(tstops, t)
Expand Down Expand Up @@ -247,7 +247,7 @@ function (analysis_callback::AnalysisCallback)(io, u_ode, integrator, semi)
println(io)

# Calculate L2/Linf errors, which are also returned
l2_error, linf_error = calc_error_norms(u_ode, t, semi)
l2_error, linf_error = calc_error_norms(q, t, semi)
current_errors = zeros(real(semi), (length(analysis_errors), nvariables(equations)))
current_errors[1, :] = l2_error
current_errors[2, :] = linf_error
Expand All @@ -266,7 +266,7 @@ function (analysis_callback::AnalysisCallback)(io, u_ode, integrator, semi)
# Conservation error
if :conservation_error in analysis_errors
@unpack initial_state_integrals = analysis_callback
state_integrals = integrate(u_ode, semi)
state_integrals = integrate(q, semi)
current_errors[3, :] = abs.(state_integrals - initial_state_integrals)
print(io, " |∫q - ∫q₀|: ")
for v in eachvariable(equations)
Expand All @@ -282,7 +282,7 @@ function (analysis_callback::AnalysisCallback)(io, u_ode, integrator, semi)
println(io, " Integrals: ")
end
current_integrals = zeros(real(semi), length(analysis_integrals))
analyze_integrals!(io, current_integrals, 1, analysis_integrals, u_ode, t, semi)
analyze_integrals!(io, current_integrals, 1, analysis_integrals, q, t, semi)
push!(integrals, current_integrals)

println(io, "─"^100)
Expand All @@ -292,26 +292,25 @@ end

# Iterate over tuples of analysis integrals in a type-stable way using "lispy tuple programming".
function analyze_integrals!(io, current_integrals, i, analysis_integrals::NTuple{N, Any},
u_ode,
t, semi) where {N}
q, t, semi) where {N}

# Extract the first analysis integral and process it; keep the remaining to be processed later
quantity = first(analysis_integrals)
remaining_quantities = Base.tail(analysis_integrals)

res = analyze(quantity, u_ode, t, semi)
res = analyze(quantity, q, t, semi)
current_integrals[i] = res
@printf(io, " %-12s:", pretty_form_utf(quantity))
@printf(io, " % 10.8e", res)
println(io)

# Recursively call this method with the unprocessed integrals
analyze_integrals!(io, current_integrals, i + 1, remaining_quantities, u_ode, t, semi)
analyze_integrals!(io, current_integrals, i + 1, remaining_quantities, q, t, semi)
return nothing
end

# terminate the type-stable iteration over tuples
function analyze_integrals!(io, current_integrals, i, analysis_integrals::Tuple{}, u_ode, t,
function analyze_integrals!(io, current_integrals, i, analysis_integrals::Tuple{}, q, t,
semi)
nothing
end
Expand All @@ -320,22 +319,21 @@ end
function (cb::DiscreteCallback{Condition, Affect!})(sol) where {Condition,
Affect! <:
AnalysisCallback}
analysis_callback = cb.affect!
semi = sol.prob.p

l2_error, linf_error = calc_error_norms(sol.u[end], sol.t[end], semi)

return (; l2 = l2_error, linf = linf_error)
end

function analyze(quantity, u_ode, t, semi::Semidiscretization)
integrate_quantity(u_ode -> quantity(u_ode, semi.equations), u_ode, semi)
function analyze(quantity, q, t, semi::Semidiscretization)
integrate_quantity(q -> quantity(q, semi.equations), q, semi)
end

# modified entropy from Svärd-Kalisch equations need to take the whole vector `u` for every point in space
function analyze(quantity::Union{typeof(energy_total_modified), typeof(entropy_modified)},
u_ode, t, semi::Semidiscretization)
integrate_quantity(quantity, u_ode, semi)
q, t, semi::Semidiscretization)
integrate_quantity(quantity, q, semi)
end

pretty_form_utf(::typeof(waterheight_total)) = "∫η"
Expand Down
40 changes: 17 additions & 23 deletions src/callbacks_step/relaxation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,45 +59,40 @@ end
@inline function (relaxation_callback::RelaxationCallback)(integrator)
semi = integrator.p
told = integrator.tprev
uold_ode = integrator.uprev
uold = wrap_array(uold_ode, semi)
qold = integrator.uprev
tnew = integrator.t
unew_ode = integrator.u
unew = wrap_array(unew_ode, semi)
qnew = integrator.u

gamma = one(tnew)
terminate_integration = false
gamma_lo = one(gamma) / 2
gamma_hi = 3 * one(gamma) / 2
gamma_lo = one(tnew) / 2
gamma_hi = 3 * one(tnew) / 2

function relaxation_functional(u, semi)
function relaxation_functional(q, semi)
@unpack tmp1 = semi.cache
# modified entropy from Svärd-Kalisch equations need to take the whole vector `u` for every point in space
# modified entropy from Svärd-Kalisch equations need to take the whole vector `q` for every point in space
if relaxation_callback.invariant isa
Union{typeof(energy_total_modified), typeof(entropy_modified)}
return integrate_quantity!(tmp1, relaxation_callback.invariant, u, semi;
wrap = false)
return integrate_quantity!(tmp1, relaxation_callback.invariant, q, semi)
else
return integrate_quantity!(tmp1,
u_ode -> relaxation_callback.invariant(u_ode,
semi.equations),
u, semi; wrap = false)
q -> relaxation_callback.invariant(q, semi.equations),
q, semi)
end
end

function convex_combination(gamma, uold, unew)
@. uold + gamma * (unew - uold)
function convex_combination(gamma, old, new)
@. old + gamma * (new - old)
end
energy_old = relaxation_functional(uold, semi)
energy_old = relaxation_functional(qold, semi)

@trixi_timeit timer() "relaxation" begin
if (relaxation_functional(convex_combination(gamma_lo, uold, unew), semi) -
if (relaxation_functional(convex_combination(gamma_lo, qold, qnew), semi) -
energy_old) *
(relaxation_functional(convex_combination(gamma_hi, uold, unew), semi) -
(relaxation_functional(convex_combination(gamma_hi, qold, qnew), semi) -
energy_old) > 0
terminate_integration = true
else
gamma = find_zero(g -> relaxation_functional(convex_combination(g, uold, unew),
gamma = find_zero(g -> relaxation_functional(convex_combination(g, qold, qnew),
semi) -
energy_old, (gamma_lo, gamma_hi), AlefeldPotraShi())
end
Expand All @@ -106,9 +101,8 @@ end
terminate_integration = true
end

unew .= convex_combination(gamma, uold, unew)
unew_ode = reshape(unew, prod(size(unew)))
DiffEqBase.set_u!(integrator, unew_ode)
qnew .= convex_combination(gamma, qold, qnew)
DiffEqBase.set_u!(integrator, qnew)
if !isapprox(tnew, first(integrator.opts.tstops))
tgamma = convex_combination(gamma, told, tnew)
DiffEqBase.set_t!(integrator, tgamma)
Expand Down
57 changes: 18 additions & 39 deletions src/equations/bbm_bbm_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ function create_cache(mesh, equations::BBMBBMEquations1D,
RealT, uEltype)
D = equations.D
invImD2 = lu(I - 1 / 6 * D^2 * sparse(solver.D2))
tmp2 = Array{RealT}(undef, nnodes(mesh))
return (invImD2 = invImD2, tmp2 = tmp2)
return (invImD2 = invImD2,)
end

function create_cache(mesh, equations::BBMBBMEquations1D,
Expand Down Expand Up @@ -164,26 +163,21 @@ function create_cache(mesh, equations::BBMBBMEquations1D,
else
@error "unknown type of first-derivative operator: $(typeof(solver.D1))"
end
tmp2 = Array{RealT}(undef, nnodes(mesh))
tmp3 = Array{RealT}(undef, nnodes(mesh) - 2)
return (invImD2d = invImD2d, invImD2n = invImD2n, tmp2 = tmp2, tmp3 = tmp3)
return (invImD2d = invImD2d, invImD2n = invImD2n)
end

# Discretization that conserves the mass (for eta and v) and the energy for periodic boundary conditions, see
# - Hendrik Ranocha, Dimitrios Mitsotakis and David I. Ketcheson (2020)
# A Broad Class of Conservative Numerical Methods for Dispersive Wave Equations
# [DOI: 10.4208/cicp.OA-2020-0119](https://doi.org/10.4208/cicp.OA-2020-0119)
function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_condition,
function rhs!(dq, q, t, mesh, equations::BBMBBMEquations1D, initial_condition,
::BoundaryConditionPeriodic, source_terms, solver, cache)
@unpack invImD2, tmp1, tmp2 = cache
@unpack invImD2 = cache

q = wrap_array(u_ode, mesh, equations, solver)
dq = wrap_array(du_ode, mesh, equations, solver)

eta = view(q, 1, :)
v = view(q, 2, :)
deta = view(dq, 1, :)
dv = view(dq, 2, :)
eta = q.x[1]
v = q.x[2]
deta = dq.x[1]
dv = dq.x[2]

D = equations.D
# energy and mass conservative semidiscretization
Expand All @@ -207,17 +201,11 @@ function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_cond
@trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations,
solver)

# To use the in-place version `ldiv!` instead of `\`, we need temporary arrays
# since `deta` and `dv` are not stored contiguously
@trixi_timeit timer() "deta elliptic" begin
tmp1[:] = deta
ldiv!(tmp2, invImD2, tmp1)
deta[:] = tmp2
ldiv!(invImD2, deta)
end
@trixi_timeit timer() "dv elliptic" begin
tmp2[:] = dv
ldiv!(tmp1, invImD2, tmp2)
dv[:] = tmp1
ldiv!(invImD2, dv)
end
return nothing
end
Expand All @@ -226,17 +214,14 @@ end
# - Hendrik Ranocha, Dimitrios Mitsotakis and David I. Ketcheson (2020)
# A Broad Class of Conservative Numerical Methods for Dispersive Wave Equations
# [DOI: 10.4208/cicp.OA-2020-0119](https://doi.org/10.4208/cicp.OA-2020-0119)
function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_condition,
function rhs!(dq, q, t, mesh, equations::BBMBBMEquations1D, initial_condition,
::BoundaryConditionReflecting, source_terms, solver, cache)
@unpack invImD2d, invImD2n, tmp1, tmp2, tmp3 = cache

q = wrap_array(u_ode, mesh, equations, solver)
dq = wrap_array(du_ode, mesh, equations, solver)
@unpack invImD2d, invImD2n = cache

eta = view(q, 1, :)
v = view(q, 2, :)
deta = view(dq, 1, :)
dv = view(dq, 2, :)
eta = q.x[1]
v = q.x[2]
deta = dq.x[1]
dv = dq.x[2]

D = equations.D
# energy and mass conservative semidiscretization
Expand All @@ -257,18 +242,12 @@ function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_cond
@trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations,
solver)

# To use the in-place version `ldiv!` instead of `\`, we need temporary arrays
# since `deta` and `dv` are not stored contiguously
@trixi_timeit timer() "deta elliptic" begin
tmp1[:] = deta
ldiv!(tmp2, invImD2n, tmp1)
deta[:] = tmp2
ldiv!(invImD2n, deta)
end
@trixi_timeit timer() "dv elliptic" begin
tmp2[:] = dv
ldiv!(tmp3, invImD2d, tmp2[2:(end - 1)])
ldiv!(invImD2d, (@view dv[2:(end - 1)]))
dv[1] = dv[end] = zero(eltype(dv))
dv[2:(end - 1)] = tmp3
end

return nothing
Expand Down
Loading
Loading