Implemented allreduce & Allreduce for Communicator #92

Merged: 22 commits, Jan 7, 2025
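
The naming follows the mpi4py convention this PR builds on: lowercase allreduce reduces generic Python objects (pickle-based) and returns the result, while capitalized Allreduce does an element-wise reduction between buffers. A minimal mpi4py-only sketch of that distinction (standard mpi4py calls, independent of this PR's code):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
# lowercase: works on any Python object, returns the reduced value
rank_sum = comm.allreduce(comm.Get_rank(), op=MPI.SUM)
# uppercase: buffer-to-buffer, element-wise, fills the receive array
send = np.full(3, float(comm.Get_rank()))
recv = np.empty(3)
comm.Allreduce(send, recv, op=MPI.MAX)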
Commits (22)
bf6e3d3  Merge pull request #50 from NOAA-GFDL/develop (fmalatino, Sep 30, 2024)
75b4741  Added an MPI all_reduce for quantities based on SUM operation to comm… (gmao-ckung, Dec 11, 2024)
4c8632c  linted (gmao-ckung, Dec 11, 2024)
a2fac9f  Add initial skeleton of pytest test for all reduce (gmao-ckung, Dec 13, 2024)
8c5b5d5  Added assertion tests for 1, 2 and 3D quantities passed through mpi_a… (gmao-ckung, Dec 13, 2024)
fb4e740  Linted (gmao-ckung, Dec 13, 2024)
34f82fb  Added pytest.mark to skip test if mpi4py isn't available (gmao-ckung, Dec 13, 2024)
b4a6a54  lint changes (gmao-ckung, Dec 16, 2024)
f5ce883  Addressed PR comments and added additional CPU backends to unit test (gmao-ckung, Dec 16, 2024)
2e41349  Merge branch 'feature/mpi_allreduce_sum' of https://github.com/NOAA-G… (gmao-ckung, Dec 16, 2024)
2e669db  Added setters for various Quantity properties to enable setting of Qu… (gmao-ckung, Dec 18, 2024)
fd2fa97  Added function in QuantityMetadata class that allows copying of Metad… (gmao-ckung, Dec 19, 2024)
cc620c6  Add `Allreduce` and all MPI OP (FlorianDeconinck, Dec 22, 2024)
0e8089e  Update utest (FlorianDeconinck, Dec 22, 2024)
2188c75  Fix `local_comm` (FlorianDeconinck, Dec 22, 2024)
f8cc2ce  Fix utest (FlorianDeconinck, Dec 22, 2024)
7ad271f  Enforce `comm_abc.Comm` into Communicator (FlorianDeconinck, Dec 22, 2024)
07cd0f3  Fix `comm` object in serial utest (FlorianDeconinck, Dec 22, 2024)
224e6e2  Lint + `MPIComm` on testing architecture (FlorianDeconinck, Dec 22, 2024)
312b492  Merge branch 'develop' into feature/mpi_allreduce_sum (FlorianDeconinck, Dec 22, 2024)
760578c  Add in_place option for Allreduce (FlorianDeconinck, Dec 30, 2024)
c758ffb  Merge branch 'develop' into feature/mpi_allreduce_sum (FlorianDeconinck, Jan 7, 2025)
12 changes: 9 additions & 3 deletions ndsl/comm/caching_comm.py
@@ -5,7 +5,7 @@

import numpy as np

- from ndsl.comm.comm_abc import Comm, Request
+ from ndsl.comm.comm_abc import Comm, ReductionOperator, Request


T = TypeVar("T")
@@ -147,9 +147,12 @@ def Split(self, color, key) -> "CachingCommReader":
new_data = self._data.get_split()
return CachingCommReader(data=new_data)

- def allreduce(self, sendobj, op=None) -> Any:
+ def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
return self._data.get_generic_obj()

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommReader.Allreduce")

@classmethod
def load(cls, file: BinaryIO) -> "CachingCommReader":
data = CachingCommData.load(file)
@@ -229,7 +232,10 @@ def Split(self, color, key) -> "CachingCommWriter":
def dump(self, file: BinaryIO):
self._data.dump(file)

- def allreduce(self, sendobj, op=None) -> Any:
+ def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
result = self._comm.allreduce(sendobj, op)
self._data.generic_obj_buffers.append(copy.deepcopy(result))
return result

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommWriter.Allreduce")
29 changes: 28 additions & 1 deletion ndsl/comm/comm_abc.py
@@ -1,10 +1,30 @@
import abc
import enum
from typing import List, Optional, TypeVar


T = TypeVar("T")


@enum.unique
class ReductionOperator(enum.Enum):
OP_NULL = enum.auto()
MAX = enum.auto()
MIN = enum.auto()
SUM = enum.auto()
PROD = enum.auto()
LAND = enum.auto()
BAND = enum.auto()
LOR = enum.auto()
BOR = enum.auto()
LXOR = enum.auto()
BXOR = enum.auto()
MAXLOC = enum.auto()
MINLOC = enum.auto()
REPLACE = enum.auto()
NO_OP = enum.auto()


class Request(abc.ABC):
@abc.abstractmethod
def wait(self):
@@ -69,5 +89,12 @@ def Split(self, color, key) -> "Comm":
...

@abc.abstractmethod
- def allreduce(self, sendobj: T, op=None) -> T:
+ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
...

@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
...

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
...
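
The enum keeps reduction semantics in this backend-agnostic layer; each concrete Comm maps the members onto its own primitives (MPIComm's _op_mapping below does exactly that for mpi4py). A toy single-process illustration of such a mapping (the dict and function are hypothetical, not PR code):

import operator
from functools import reduce
from ndsl.comm.comm_abc import ReductionOperator

# Hypothetical mapping: translate the backend-agnostic enum into
# concrete binary reductions.
_PY_OPS = {
    ReductionOperator.SUM: operator.add,
    ReductionOperator.PROD: operator.mul,
    ReductionOperator.MAX: max,
    ReductionOperator.MIN: min,
}

def toy_allreduce(per_rank_values, op: ReductionOperator):
    """Reduce a list of per-rank values the way a communicator would."""
    return reduce(_PY_OPS[op], per_rank_values)

assert toy_allreduce([1, 2, 3], ReductionOperator.SUM) == 6
assert toy_allreduce([1, 2, 3], ReductionOperator.MAX) == 3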
79 changes: 68 additions & 11 deletions ndsl/comm/communicator.py
@@ -6,6 +6,8 @@
import ndsl.constants as constants
from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer
from ndsl.comm.boundary import Boundary
from ndsl.comm.comm_abc import Comm as CommABC
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner
from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
from ndsl.performance.timer import NullTimer, Timer
@@ -44,7 +46,11 @@ def to_numpy(array, dtype=None) -> np.ndarray:

class Communicator(abc.ABC):
def __init__(
- self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None
+ self,
+ comm: CommABC,
+ partitioner,
+ force_cpu: bool = False,
+ timer: Optional[Timer] = None,
):
self.comm = comm
self.partitioner: Partitioner = partitioner
@@ -61,7 +67,7 @@ def tile(self) -> "TileCommunicator":
@abc.abstractmethod
def from_layout(
cls,
- comm,
+ comm: CommABC,
layout: Tuple[int, int],
force_cpu: bool = False,
timer: Optional[Timer] = None,
@@ -93,17 +99,63 @@ def _device_synchronize():
# this is a method so we can profile it separately from other device syncs
device_synchronize()

def _create_all_reduce_quantity(
self, input_metadata: QuantityMetadata, input_data
) -> Quantity:
"""Create a Quantity for all_reduce data and metadata"""
all_reduce_quantity = Quantity(
input_data,
dims=input_metadata.dims,
units=input_metadata.units,
origin=input_metadata.origin,
extent=input_metadata.extent,
gt4py_backend=input_metadata.gt4py_backend,
allow_mismatch_float_precision=False,
)
return all_reduce_quantity

def all_reduce(
self,
input_quantity: Quantity,
op: ReductionOperator,
output_quantity: Quantity = None,
):
reduced_quantity_data = self.comm.allreduce(input_quantity.data, op)
if output_quantity is None:
all_reduce_quantity = self._create_all_reduce_quantity(
input_quantity.metadata, reduced_quantity_data
)
return all_reduce_quantity
else:
if output_quantity.data.shape != input_quantity.data.shape:
raise TypeError("Shapes not matching")

input_quantity.metadata.duplicate_metadata(output_quantity.metadata)

output_quantity.data = reduced_quantity_data

def all_reduce_per_element(
self,
input_quantity: Quantity,
output_quantity: Quantity,
op: ReductionOperator,
):
self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def all_reduce_per_element_in_place(
self, quantity: Quantity, op: ReductionOperator
):
self.comm.Allreduce_inplace(quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
- with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
-     numpy_module.zeros, recvbuf
- ) as recv:
-     self.comm.Scatter(send, recv, **kwargs)
+ with send_buffer(numpy_module.zeros, sendbuf) as send:
+     with recv_buffer(numpy_module.zeros, recvbuf) as recv:
+         self.comm.Scatter(send, recv, **kwargs)

def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs):
- with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
-     numpy_module.zeros, recvbuf
- ) as recv:
-     self.comm.Gather(send, recv, **kwargs)
+ with send_buffer(numpy_module.zeros, sendbuf) as send:
+     with recv_buffer(numpy_module.zeros, recvbuf) as recv:
+         self.comm.Gather(send, recv, **kwargs)

def scatter(
self,
Expand Down Expand Up @@ -709,7 +761,7 @@ class CubedSphereCommunicator(Communicator):

def __init__(
self,
- comm,
+ comm: CommABC,
partitioner: CubedSpherePartitioner,
force_cpu: bool = False,
timer: Optional[Timer] = None,
@@ -722,6 +774,11 @@ def __init__(
force_cpu: Force all communication to go through central memory.
timer: Time communication operations.
"""
if not isinstance(comm, CommABC):
    raise TypeError(
        "Communicator needs to be instantiated with a communication subsystem"
        f" derived from `comm_abc.Comm`, got {type(comm)}."
    )
if comm.Get_size() != partitioner.total_ranks:
raise ValueError(
f"was given a partitioner for {partitioner.total_ranks} ranks but a "
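
Taken together, Communicator now offers three reduction entry points over Quantity data: object-mode all_reduce (returns or fills a Quantity), buffer-mode all_reduce_per_element, and the in-place variant. A hedged usage sketch (the helper name and call site are illustrative, not from the PR):

from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.communicator import Communicator
from ndsl.quantity import Quantity

def global_sum(communicator: Communicator, quantity: Quantity) -> Quantity:
    """Element-wise sum of a Quantity across all ranks.

    Without an output_quantity, all_reduce allocates a new Quantity from
    the input's metadata; pass an output_quantity of matching shape to
    reuse an existing buffer instead.
    """
    return communicator.all_reduce(quantity, ReductionOperator.SUM)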
10 changes: 8 additions & 2 deletions ndsl/comm/local_comm.py
@@ -189,8 +189,14 @@ def Split(self, color, key):
self._split_comms[color].append(new_comm)
return new_comm

- def allreduce(self, sendobj, op=None) -> Any:
+ def allreduce(self, sendobj, op=None, recvobj=None) -> Any:
raise NotImplementedError(
"sendrecv fundamentally cannot be written for LocalComm, "
"allreduce fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)

def Allreduce(self, sendobj, recvobj, op) -> Any:
raise NotImplementedError(
"Allreduce fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)
41 changes: 37 additions & 4 deletions ndsl/comm/mpi.py
@@ -1,17 +1,36 @@
try:
import mpi4py
from mpi4py import MPI
except ImportError:
MPI = None
- from typing import List, Optional, TypeVar, cast
+ from typing import Dict, List, Optional, TypeVar, cast

- from ndsl.comm.comm_abc import Comm, Request
+ from ndsl.comm.comm_abc import Comm, ReductionOperator, Request
from ndsl.logging import ndsl_log


T = TypeVar("T")


class MPIComm(Comm):
_op_mapping: Dict[ReductionOperator, mpi4py.MPI.Op] = {
ReductionOperator.OP_NULL: mpi4py.MPI.OP_NULL,
ReductionOperator.MAX: mpi4py.MPI.MAX,
ReductionOperator.MIN: mpi4py.MPI.MIN,
ReductionOperator.SUM: mpi4py.MPI.SUM,
ReductionOperator.PROD: mpi4py.MPI.PROD,
ReductionOperator.LAND: mpi4py.MPI.LAND,
ReductionOperator.BAND: mpi4py.MPI.BAND,
ReductionOperator.LOR: mpi4py.MPI.LOR,
ReductionOperator.BOR: mpi4py.MPI.BOR,
ReductionOperator.LXOR: mpi4py.MPI.LXOR,
ReductionOperator.BXOR: mpi4py.MPI.BXOR,
ReductionOperator.MAXLOC: mpi4py.MPI.MAXLOC,
ReductionOperator.MINLOC: mpi4py.MPI.MINLOC,
ReductionOperator.REPLACE: mpi4py.MPI.REPLACE,
ReductionOperator.NO_OP: mpi4py.MPI.NO_OP,
}

def __init__(self):
if MPI is None:
raise RuntimeError("MPI not available")
@@ -72,8 +91,22 @@ def Split(self, color, key) -> "Comm":
)
return self._comm.Split(color, key)

- def allreduce(self, sendobj: T, op=None) -> T:
+ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
ndsl_log.debug(
"allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
- return self._comm.allreduce(sendobj, op)
+ return self._comm.allreduce(sendobj, self._op_mapping[op])

def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op])

def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce (in place) on rank %s with operator %s",
self._comm.Get_rank(),
op,
)
return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op])
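
A minimal end-to-end sketch of the wrapper (assumes mpi4py is installed; run under e.g. mpirun -n 4 python demo.py, a hypothetical script name). Passing op explicitly is advisable, since the _op_mapping lookup above does not special-case None:

import numpy as np

from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.mpi import MPIComm

comm = MPIComm()

# Object-mode: every rank contributes its rank id, all receive the sum.
total = comm.allreduce(comm.Get_rank(), ReductionOperator.SUM)

# Buffer-mode: element-wise max across ranks into a preallocated array.
send = np.full(3, float(comm.Get_rank()))
recv = np.empty_like(send)
comm.Allreduce(send, recv, ReductionOperator.MAX)

# In-place variant: `send` is overwritten with the element-wise minimum.
comm.Allreduce_inplace(send, ReductionOperator.MIN)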
10 changes: 7 additions & 3 deletions ndsl/comm/null_comm.py
@@ -1,7 +1,7 @@
import copy
- from typing import Any, Mapping
+ from typing import Any, Mapping, Optional

- from ndsl.comm.comm_abc import Comm, Request
+ from ndsl.comm.comm_abc import Comm, ReductionOperator, Request


class NullAsyncResult(Request):
@@ -91,5 +91,9 @@ def Split(self, color, key):
self._split_comms[color].append(new_comm)
return new_comm

- def allreduce(self, sendobj, op=None) -> Any:
+ def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
return self._fill_value

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
    # Copy into the caller's buffer; rebinding the parameter name alone
    # would be a no-op from the caller's point of view.
    recvobj[:] = sendobj
    return recvobj
14 changes: 14 additions & 0 deletions ndsl/quantity.py
@@ -53,6 +53,15 @@ def np(self) -> NumpyModule:
f"quantity underlying data is of unexpected type {self.data_type}"
)

def duplicate_metadata(self, metadata_copy):
metadata_copy.origin = self.origin
metadata_copy.extent = self.extent
metadata_copy.dims = self.dims
metadata_copy.units = self.units
metadata_copy.data_type = self.data_type
metadata_copy.dtype = self.dtype
metadata_copy.gt4py_backend = self.gt4py_backend


@dataclasses.dataclass
class QuantityHaloSpec:
@@ -492,6 +501,11 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]:
"""the underlying array of data"""
return self._data

@data.setter
def data(self, inputData):
if type(inputData) in [np.ndarray, cupy.ndarray]:
self._data = inputData

@property
def origin(self) -> Tuple[int, ...]:
"""the start of the computational domain"""
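
One subtlety of the new helper: duplicate_metadata copies from self onto its argument, so the source appears first at the call site. A short sketch (the wrapper function is illustrative, not from the PR):

from ndsl.quantity import Quantity

def mirror_metadata(source: Quantity, target: Quantity) -> None:
    """Overwrite target's metadata fields (dims, units, origin, extent, ...)
    in place with source's values."""
    source.metadata.duplicate_metadata(target.metadata)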
11 changes: 5 additions & 6 deletions ndsl/stencils/testing/conftest.py
@@ -13,7 +13,7 @@
CubedSphereCommunicator,
TileCommunicator,
)
- from ndsl.comm.mpi import MPI
+ from ndsl.comm.mpi import MPI, MPIComm
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
from ndsl.dsl.dace.dace_config import DaceConfig
from ndsl.namelist import Namelist
@@ -323,7 +323,7 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode):
npx=namelist.npx,
npy=namelist.npy,
npz=namelist.npz,
- communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode),
+ communicator=get_communicator(MPIComm(), layout, topology_mode),
backend=backend,
)

Expand Down Expand Up @@ -377,13 +377,12 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str):
metafunc.config
)
# get MPI environment
- comm = MPI.COMM_WORLD
- mpi_rank = comm.Get_rank()
+ comm = MPIComm()
savepoint_cases = parallel_savepoint_cases(
metafunc,
data_path,
namelist_filename,
- mpi_rank,
+ comm.Get_rank(),
backend=backend,
comm=comm,
)
@@ -393,7 +392,7 @@


def get_communicator(comm, layout, topology_mode):
- if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "cubed-sphere"):
+ if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"):
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
communicator = CubedSphereCommunicator(comm, partitioner)
else: