diff --git a/src/Grids/Grids.jl b/src/Grids/Grids.jl index 39f0f235d2..5261c3197d 100644 --- a/src/Grids/Grids.jl +++ b/src/Grids/Grids.jl @@ -76,6 +76,7 @@ include("spectralelement.jl") include("extruded.jl") include("column.jl") include("level.jl") +include("convenience_constructors.jl") diff --git a/src/Grids/convenience_constructors.jl b/src/Grids/convenience_constructors.jl new file mode 100644 index 0000000000..fba855c940 --- /dev/null +++ b/src/Grids/convenience_constructors.jl @@ -0,0 +1,372 @@ +#= +These convenience constructors accept integer +keyword inputs, so they are dynamically created. You may +want to use a different constructor if you're making the +object in a performance-critical section, and if you know +the type parameters at compile time. + +If no convenience constructor exists to make the +grid you need, then you may need to use our lower +level compose-able API. +=# +check_device_context(context, device) = + @assert ClimaComms.device(context) == device "The given device and context device do not match." + +""" + ExtrudedCubedSphereGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + z_elem::Integer, + z_min::Real, + z_max::Real, + radius::Real, + h_elem::Integer, + n_quad_points::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.ShallowSphericalGlobalGeometry(radius), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + h_mesh = Meshes.EquiangularCubedSphere(Domains.SphereDomain(radius), h_elem), + ) + +A convenience constructor, which builds a +`ExtrudedFiniteDifferenceGrid`. +""" +ExtrudedCubedSphereGrid(; kwargs...) = + ExtrudedCubedSphereGrid(Float64; kwargs...) + +function ExtrudedCubedSphereGrid( + ::Type{FT}; + z_elem::Integer, + z_min::Real, + z_max::Real, + radius::Real, + h_elem::Integer, + n_quad_points::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.ShallowSphericalGlobalGeometry( + radius, + ), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + h_mesh = Meshes.EquiangularCubedSphere( + Domains.SphereDomain{FT}(radius), + h_elem, + ), +) where {FT} + check_device_context(context, device) + + z_boundary_names = (:bottom, :top) + h_topology = Topologies.Topology2D(context, h_mesh) + h_grid = Grids.SpectralElementGrid2D(h_topology, quad) + z_domain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(z_min), + Geometry.ZPoint{FT}(z_max); + boundary_names = z_boundary_names, + ) + z_mesh = Meshes.IntervalMesh(z_domain, stretch; nelems = z_elem) + z_topology = Topologies.IntervalTopology(context, z_mesh) + vertical_grid = FiniteDifferenceGrid(z_topology) + return ExtrudedFiniteDifferenceGrid( + h_grid, + vertical_grid, + hypsography, + global_geometry, + ) +end + +""" + CubedSphereGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + radius::Real, + n_quad_points::Integer, + h_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + h_mesh = Meshes.EquiangularCubedSphere(Domains.SphereDomain(radius), h_elem), + ) + +A convenience constructor, which builds a +`SpectralElementGrid2D`. +""" +CubedSphereGrid(; kwargs...) = CubedSphereGrid(Float64; kwargs...) +function CubedSphereGrid( + ::Type{FT}; + radius::Real, + n_quad_points::Integer, + h_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + h_mesh = Meshes.EquiangularCubedSphere( + Domains.SphereDomain{FT}(radius), + h_elem, + ), +) where {FT} + check_device_context(context, device) + h_topology = Topologies.Topology2D(context, h_mesh) + return Grids.SpectralElementGrid2D(h_topology, quad) +end + +""" + ColumnGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + z_elem::Integer, + z_min::Real, + z_max::Real, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + ) + +A convenience constructor, which builds a +`FiniteDifferenceGrid` given. +""" +ColumnGrid(; kwargs...) = ColumnGrid(Float64; kwargs...) +function ColumnGrid( + ::Type{FT}; + z_elem::Integer, + z_min::Real, + z_max::Real, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), +) where {FT} + check_device_context(context, device) + z_boundary_names = (:bottom, :top) + z_domain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(z_min), + Geometry.ZPoint{FT}(z_max); + boundary_names = z_boundary_names, + ) + z_mesh = Meshes.IntervalMesh(z_domain, stretch; nelems = z_elem) + z_topology = Topologies.IntervalTopology(context, z_mesh) + return FiniteDifferenceGrid(z_topology) +end + +""" + Box3DGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + z_elem::Integer, + x_min::Real, + x_max::Real, + y_min::Real, + y_max::Real, + z_min::Real, + z_max::Real, + periodic_x::Bool, + periodic_y::Bool, + n_quad_points::Integer, + x_elem::Integer, + y_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + ) + +A convenience constructor, which builds a +`ExtrudedFiniteDifferenceGrid` with a +`FiniteDifferenceGrid` vertical grid and a +`SpectralElementGrid2D` horizontal grid. +""" +Box3DGrid(; kwargs...) = Box3DGrid(Float64; kwargs...) +function Box3DGrid( + ::Type{FT}; + z_elem::Integer, + x_min::Real, + x_max::Real, + y_min::Real, + y_max::Real, + z_min::Real, + z_max::Real, + periodic_x::Bool, + periodic_y::Bool, + n_quad_points::Integer, + x_elem::Integer, + y_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), +) where {FT} + check_device_context(context, device) + x1boundary = (:east, :west) + x2boundary = (:south, :north) + z_boundary_names = (:bottom, :top) + domain = Domains.RectangleDomain( + Domains.IntervalDomain( + Geometry.XPoint{FT}(x_min), + Geometry.XPoint{FT}(x_max); + periodic = periodic_x, + boundary_names = x1boundary, + ), + Domains.IntervalDomain( + Geometry.YPoint{FT}(y_min), + Geometry.YPoint{FT}(y_max); + periodic = periodic_y, + boundary_names = x2boundary, + ), + ) + h_mesh = Meshes.RectilinearMesh(domain, x_elem, y_elem) + h_topology = Topologies.Topology2D(context, h_mesh) + h_grid = Grids.SpectralElementGrid2D(h_topology, quad) + z_domain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(z_min), + Geometry.ZPoint{FT}(z_max); + boundary_names = z_boundary_names, + ) + z_mesh = Meshes.IntervalMesh(z_domain, stretch; nelems = z_elem) + z_topology = Topologies.IntervalTopology(context, z_mesh) + vertical_grid = FiniteDifferenceGrid(z_topology) + return ExtrudedFiniteDifferenceGrid( + h_grid, + vertical_grid, + hypsography, + global_geometry, + ) +end + +""" + SliceXZGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + z_elem::Integer, + x_min::Real, + x_max::Real, + z_min::Real, + z_max::Real, + periodic_x::Bool, + n_quad_points::Integer, + x_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + ) + +A convenience constructor, which builds a +`ExtrudedFiniteDifferenceGrid` with a +`FiniteDifferenceGrid` vertical grid and a +`SpectralElementGrid1D` horizontal grid. + - `` +""" +SliceXZGrid(; kwargs...) = SliceXZGrid(Float64; kwargs...) +function SliceXZGrid( + ::Type{FT}; + z_elem::Integer, + x_min::Real, + x_max::Real, + z_min::Real, + z_max::Real, + periodic_x::Bool, + n_quad_points::Integer, + x_elem::Integer, + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + stretch::Meshes.StretchingRule = Meshes.Uniform(), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), +) where {FT} + check_device_context(context, device) + + x1boundary = (:east, :west) + z_boundary_names = (:bottom, :top) + h_domain = Domains.IntervalDomain( + Geometry.XPoint{FT}(x_min), + Geometry.XPoint{FT}(x_max); + periodic = periodic_x, + boundary_names = x1boundary, + ) + h_mesh = Meshes.IntervalMesh(h_domain; nelems = x_elem) + h_topology = Topologies.IntervalTopology(context, h_mesh) + h_grid = Grids.SpectralElementGrid1D(h_topology, quad) + z_domain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(z_min), + Geometry.ZPoint{FT}(z_max); + boundary_names = z_boundary_names, + ) + z_mesh = Meshes.IntervalMesh(z_domain, stretch; nelems = z_elem) + z_topology = Topologies.IntervalTopology(context, z_mesh) + vertical_grid = FiniteDifferenceGrid(z_topology) + return ExtrudedFiniteDifferenceGrid( + h_grid, + vertical_grid, + hypsography, + global_geometry, + ) +end + +""" + RectangleXYGrid( + ::Type{<:AbstractFloat}; # defaults to Float64 + x_min::Real, + x_max::Real, + y_min::Real, + y_max::Real, + periodic_x::Bool, + periodic_y::Bool, + n_quad_points::Integer, + x_elem::Integer, # number of horizontal elements + y_elem::Integer, # number of horizontal elements + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), + ) + +A convenience constructor, which builds a +`SpectralElementGrid2D` with a horizontal +`RectilinearMesh` mesh. +""" +RectangleXYGrid(; kwargs...) = RectangleXYGrid(Float64; kwargs...) +function RectangleXYGrid( + ::Type{FT}; + x_min::Real, + x_max::Real, + y_min::Real, + y_max::Real, + periodic_x::Bool, + periodic_y::Bool, + n_quad_points::Integer, + x_elem::Integer, # number of horizontal elements + y_elem::Integer, # number of horizontal elements + device::ClimaComms.AbstractDevice = ClimaComms.device(), + context::ClimaComms.AbstractCommsContext = ClimaComms.context(device), + hypsography::HypsographyAdaption = Flat(), + global_geometry::Geometry.AbstractGlobalGeometry = Geometry.CartesianGlobalGeometry(), + quad::Quadratures.QuadratureStyle = Quadratures.GLL{n_quad_points}(), +) where {FT} + check_device_context(context, device) + + x1boundary = (:east, :west) + x2boundary = (:south, :north) + domain = Domains.RectangleDomain( + Domains.IntervalDomain( + Geometry.XPoint{FT}(x_min), + Geometry.XPoint{FT}(x_max); + periodic = periodic_x, + boundary_names = x1boundary, + ), + Domains.IntervalDomain( + Geometry.YPoint{FT}(y_min), + Geometry.YPoint{FT}(y_max); + periodic = periodic_y, + boundary_names = x2boundary, + ), + ) + h_mesh = Meshes.RectilinearMesh(domain, x_elem, y_elem) + h_topology = Topologies.Topology2D(context, h_mesh) + return Grids.SpectralElementGrid2D(h_topology, quad) +end diff --git a/test/Grids/convenience_constructors.jl b/test/Grids/convenience_constructors.jl new file mode 100644 index 0000000000..1333b1896f --- /dev/null +++ b/test/Grids/convenience_constructors.jl @@ -0,0 +1,76 @@ +#= +julia --project +using Revise; include(joinpath("test", "Grids", "convenience_constructors.jl")) +=# +import ClimaComms +ClimaComms.@import_required_backends +using ClimaCore: Grids, Topologies, Meshes +using Test + +@testset "Convenience constructors" begin + grid = Grids.ExtrudedCubedSphereGrid(; + z_elem = 10, + z_min = 0, + z_max = 1, + radius = 10, + h_elem = 10, + n_quad_points = 4, + ) + @test grid isa Grids.ExtrudedFiniteDifferenceGrid + @test grid.horizontal_grid isa Grids.SpectralElementGrid2D + @test Grids.topology(grid.horizontal_grid).mesh isa + Meshes.EquiangularCubedSphere + + grid = Grids.CubedSphereGrid(; radius = 10, n_quad_points = 4, h_elem = 10) + @test grid isa Grids.SpectralElementGrid2D + @test Grids.topology(grid).mesh isa Meshes.EquiangularCubedSphere + + grid = Grids.ColumnGrid(; z_elem = 10, z_min = 0, z_max = 1) + @test grid isa Grids.FiniteDifferenceGrid + + grid = Grids.Box3DGrid(; + z_elem = 10, + x_min = 0, + x_max = 1, + y_min = 0, + y_max = 1, + z_min = 0, + z_max = 10, + x1periodic = false, + x2periodic = false, + n_quad_points = 4, + n1 = 3, + n2 = 4, + ) + @test grid isa Grids.ExtrudedFiniteDifferenceGrid + @test grid.horizontal_grid isa Grids.SpectralElementGrid2D + @test Grids.topology(grid.horizontal_grid).mesh isa Meshes.RectilinearMesh + + grid = Grids.SliceXZGrid(; + z_elem = 10, + x_min = 0, + x_max = 1, + z_min = 0, + z_max = 1, + x1periodic = false, + n_quad_points = 4, + n1 = 4, + ) + @test grid isa Grids.ExtrudedFiniteDifferenceGrid + @test grid.horizontal_grid isa Grids.SpectralElementGrid1D + @test Grids.topology(grid.horizontal_grid).mesh isa Meshes.IntervalMesh + + grid = Grids.RectangleXYGrid(; + x_min = 0, + x_max = 1, + y_min = 0, + y_max = 1, + x1periodic = false, + x2periodic = false, + n_quad_points = 4, + n1 = 3, + n2 = 4, + ) + @test grid isa Grids.SpectralElementGrid2D + @test Grids.topology(grid).mesh isa Meshes.RectilinearMesh +end diff --git a/test/runtests.jl b/test/runtests.jl index 9ffc4d01b5..e6705fd3ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,7 @@ UnitTest("Cubedsphere topology" ,"Topologies/cubedsphere.jl") UnitTest("Cubedsphere surface topology" ,"Topologies/cubedsphere_sfc.jl"), UnitTest("dss_transform" ,"Topologies/unit_dss_transform.jl"), UnitTest("Quadratures" ,"Quadratures/Quadratures.jl"), +UnitTest("Grids: Convenience constructors" ,"Grids/convenience_constructors.jl"), UnitTest("Spaces" ,"Spaces/unit_spaces.jl"), UnitTest("dss" ,"Spaces/unit_dss.jl"), UnitTest("Spaces - serial CPU DSS" ,"Spaces/ddss1.jl"),