Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global literal & temporary precision (for int & float) extended to gt4py and centralized #94

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ndsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import dsl # isort:skip
from .comm.communicator import CubedSphereCommunicator, TileCommunicator
from .comm.local_comm import LocalComm
from .comm.mpi import MPIComm
Expand Down
23 changes: 19 additions & 4 deletions ndsl/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import gt4py.cartesian.config
# Literal precision for both GT4Py & NDSL
import os
import sys

from ndsl.comm.mpi import MPI

gt4py_config_module = "gt4py.cartesian.config"
if gt4py_config_module in sys.modules:
raise RuntimeError(
"`GT4Py` config imported before `ndsl` imported."
" Please import `ndsl.dsl` or any `ndsl` module "
" before any `gt4py` imports."
)
NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64"))
os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION)


# Set cache names for default gt backends workflow
import gt4py.cartesian.config # noqa: E402

from ndsl.comm.mpi import MPI # noqa: E402


if MPI is not None:
Expand All @@ -9,5 +26,3 @@
gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get(
"GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}"
)

__version__ = "0.2.0"
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from ndsl.comm.communicator import Communicator
from ndsl.comm.partitioner import Partitioner
from ndsl.dsl import NDSL_GLOBAL_PRECISION
from ndsl.dsl.caches.cache_location import identify_code_path
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import floating_point_precision
from ndsl.optional_imports import cupy as cp


Expand Down Expand Up @@ -264,7 +264,7 @@ def __init__(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
)

if floating_point_precision() == 32:
if NDSL_GLOBAL_PRECISION == 32:
# When using 32-bit float, we flip the default dtypes to be all
# C, e.g. 32 bit.
dace.Config.set(
Expand Down
39 changes: 17 additions & 22 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Tuple, Union, cast
from typing import Tuple, TypeAlias, Union, cast

import gt4py.cartesian.gtscript as gtscript
import numpy as np

from ndsl.dsl import NDSL_GLOBAL_PRECISION


# A Field
Field = gtscript.Field
Expand All @@ -21,36 +22,30 @@
# Union of valid data types (from gt4py.cartesian.gtscript)
DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64]


def floating_point_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


# We redefine the type as a way to distinguish
# the model definition of a float to other usage of the
# common numpy type in the rest of the code.
NDSL_32BIT_FLOAT_TYPE = np.float32
NDSL_64BIT_FLOAT_TYPE = np.float64
NDSL_32BIT_FLOAT_TYPE: TypeAlias = np.float32
NDSL_32BIT_INT_TYPE: TypeAlias = np.int32
NDSL_64BIT_FLOAT_TYPE: TypeAlias = np.float64
NDSL_64BIT_INT_TYPE: TypeAlias = np.int64


def global_set_floating_point_precision():
def global_set_floating_point_precision() -> Tuple[TypeAlias, TypeAlias]:
"""Set the global floating point precision for all reference
to Float in the codebase. Defaults to 64 bit."""
global Float
precision_in_bit = floating_point_precision()
if precision_in_bit == 64:
return NDSL_64BIT_FLOAT_TYPE
elif precision_in_bit == 32:
return NDSL_32BIT_FLOAT_TYPE
else:
NotImplementedError(
f"{precision_in_bit} bit precision not implemented or tested"
)
global Float, Int
if NDSL_GLOBAL_PRECISION == 64:
return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE
elif NDSL_GLOBAL_PRECISION == 32:
return NDSL_32BIT_FLOAT_TYPE, NDSL_32BIT_INT_TYPE
raise NotImplementedError(
f"{NDSL_GLOBAL_PRECISION} bit precision not implemented or tested"
)


# Default float and int types
Float = global_set_floating_point_precision()
Int = np.int_
Float, Int = global_set_floating_point_precision()
Bool = np.bool_

FloatField = Field[gtscript.IJK, Float]
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import pytest


try:
import ndsl.dsl # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError("NDSL cannot be loaded")

try:
import gt4py
except ModuleNotFoundError:
Expand Down
Loading