From b72d0ad7dbbed852873026e08ed3e793265e7718 Mon Sep 17 00:00:00 2001 From: Federico Busato <50413820+fbusato@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:26:32 -0800 Subject: [PATCH 1/4] Fx thread-reduce performance regression (#3225) --- cub/cub/thread/thread_reduce.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cub/cub/thread/thread_reduce.cuh b/cub/cub/thread/thread_reduce.cuh index 5727b395b04..294bc449e31 100644 --- a/cub/cub/thread/thread_reduce.cuh +++ b/cub/cub/thread/thread_reduce.cuh @@ -627,7 +627,8 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input& ::cuda::minimum<>, ::cuda::minimum, cub::internal::SimdMin, - cub::internal::SimdMax>()) + cub::internal::SimdMax>() + || sizeof(ValueT) >= 8) { return cub::internal::ThreadReduceSequential(input, reduction_op); } From 0e37b11b0e476275f17072caa4aa04ec06169123 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Wed, 1 Jan 2025 06:17:49 -0500 Subject: [PATCH 2/4] cuda.parallel: In-memory caching of build objects (#3216) * Define __eq__ and __hash__ for Iterators * Define cache_with_key utility and use it to cache Reduce objects * Add tests for caching Reduce objects * Tighten up types * Updates to support 3.7 * Address review feedback * Introduce IteratorKind to hold iterator type information * Use the .kind to generate an abi_name * Remove __eq__ and __hash__ methods from IteratorBase * Move helper function * Formatting * Don't unpack tuple in cache key --------- Co-authored-by: Ashwin Srinath --- .../cuda/parallel/experimental/_caching.py | 64 ++++ .../parallel/experimental/_utils/__init__.py | 0 .../cuda/parallel/experimental/_utils/cai.py | 16 + .../experimental/algorithms/reduce.py | 37 ++- .../experimental/iterators/_iterators.py | 158 ++++++---- .../cuda/parallel/experimental/typing.py | 12 + python/cuda_parallel/setup.py | 3 +- python/cuda_parallel/tests/test_iterators.py | 71 +++++ python/cuda_parallel/tests/test_reduce.py | 278 ++++++++++++++++-- 9 files changed, 550 insertions(+), 89 deletions(-) create mode 100644 python/cuda_parallel/cuda/parallel/experimental/_caching.py create mode 100644 python/cuda_parallel/cuda/parallel/experimental/_utils/__init__.py create mode 100644 python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py create mode 100644 python/cuda_parallel/cuda/parallel/experimental/typing.py create mode 100644 python/cuda_parallel/tests/test_iterators.py diff --git a/python/cuda_parallel/cuda/parallel/experimental/_caching.py b/python/cuda_parallel/cuda/parallel/experimental/_caching.py new file mode 100644 index 00000000000..2647c1835ee --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/_caching.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import functools +from numba import cuda + + +def cache_with_key(key): + """ + Decorator to cache the result of the decorated function. Uses the + provided `key` function to compute the key for cache lookup. `key` + receives all arguments passed to the function. + + Notes + ----- + The CUDA compute capability of the current device is appended to + the cache key returned by `key`. 
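+
+    A minimal usage sketch (illustrative only; ``make_thing`` and its
+    key function below are hypothetical, not part of this patch):
+
+        @cache_with_key(lambda dtype: dtype)
+        def make_thing(dtype):
+            return object()  # stands in for an expensive build step
+
+        # on the same device, equal keys return the cached object:
+        assert make_thing("int32") is make_thing("int32")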
+ """ + + def deco(func): + cache = {} + + @functools.wraps(func) + def inner(*args, **kwargs): + cc = cuda.get_current_device().compute_capability + cache_key = (key(*args, **kwargs), cc) + if cache_key not in cache: + result = func(*args, **kwargs) + cache[cache_key] = result + return cache[cache_key] + + return inner + + return deco + + +class CachableFunction: + """ + A type that wraps a function and provides custom comparison + (__eq__) and hash (__hash__) implementations. + + The purpose of this class is to enable caching and comparison of + functions based on their bytecode, constants, and closures, while + ignoring other attributes such as their names or docstrings. + """ + + def __init__(self, func): + self._func = func + self._identity = ( + self._func.__code__.co_code, + self._func.__code__.co_consts, + self._func.__closure__, + ) + + def __eq__(self, other): + return self._identity == other._identity + + def __hash__(self): + return hash(self._identity) + + def __repr__(self): + return str(self._func) diff --git a/python/cuda_parallel/cuda/parallel/experimental/_utils/__init__.py b/python/cuda_parallel/cuda/parallel/experimental/_utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py b/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py new file mode 100644 index 00000000000..9c0718e71f0 --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Utilities for extracting information from `__cuda_array_interface__`. +""" + +import numpy as np + +from ..typing import DeviceArrayLike + + +def get_dtype(arr: DeviceArrayLike) -> np.dtype: + return np.dtype(arr.__cuda_array_interface__["typestr"]) diff --git a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py index 29e7786b5f8..41843742827 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py +++ b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from __future__ import annotations # TODO: required for Python 3.7 docs env + import ctypes import numba import numpy as np @@ -12,6 +14,10 @@ from .. 
import _cccl as cccl from .._bindings import get_paths, get_bindings +from .._caching import CachableFunction, cache_with_key +from ..typing import DeviceArrayLike +from ..iterators._iterators import IteratorBase +from .._utils import cai class _Op: @@ -41,12 +47,18 @@ def _dtype_validation(dt1, dt2): class _Reduce: # TODO: constructor shouldn't require concrete `d_in`, `d_out`: - def __init__(self, d_in, d_out, op: Callable, h_init: np.ndarray): + def __init__( + self, + d_in: DeviceArrayLike | IteratorBase, + d_out: DeviceArrayLike, + op: Callable, + h_init: np.ndarray, + ): d_in_cccl = cccl.to_cccl_iter(d_in) self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name( d_in_cccl.value_type.type.value ) - self._ctor_d_out_dtype = d_out.dtype + self._ctor_d_out_dtype = cai.get_dtype(d_out) self._ctor_init_dtype = h_init.dtype cc_major, cc_minor = cuda.get_current_device().compute_capability cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths() @@ -119,9 +131,28 @@ def __del__(self): bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result)) +def make_cache_key( + d_in: DeviceArrayLike | IteratorBase, + d_out: DeviceArrayLike, + op: Callable, + h_init: np.ndarray, +): + d_in_key = d_in.kind if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in) + d_out_key = cai.get_dtype(d_out) + op_key = CachableFunction(op) + h_init_key = h_init.dtype + return (d_in_key, d_out_key, op_key, h_init_key) + + # TODO Figure out `sum` without operator and initial value # TODO Accept stream -def reduce_into(d_in, d_out, op: Callable, h_init: np.ndarray): +@cache_with_key(make_cache_key) +def reduce_into( + d_in: DeviceArrayLike | IteratorBase, + d_out: DeviceArrayLike, + op: Callable, + h_init: np.ndarray, +): """Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``. Example: diff --git a/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py b/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py index e92578efbe7..babc4a12a11 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py +++ b/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py @@ -1,15 +1,19 @@ import ctypes import operator +import uuid from functools import lru_cache from typing import Dict, Callable from llvmlite import ir from numba.core.extending import intrinsic, overload from numba.core.typing.ctypes_utils import to_ctypes +from numba.cuda.dispatcher import CUDADispatcher from numba import cuda, types import numba import numpy as np +from .._caching import CachableFunction + _DEVICE_POINTER_SIZE = 8 _DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8 @@ -20,16 +24,35 @@ def cached_compile(func, sig, abi_name=None, **kwargs): return cuda.compile(func, sig, abi_info={"abi_name": abi_name}, **kwargs) +class IteratorKind: + def __init__(self, value_type): + self.value_type = value_type + + def __repr__(self): + return f"{self.__class__.__name__}[{str(self.value_type)}]" + + def __eq__(self, other): + return type(self) is type(other) and self.value_type == other.value_type + + def __hash__(self): + return hash(self.value_type) + + +@lru_cache(maxsize=None) +def _get_abi_suffix(kind: IteratorKind): + # given an IteratorKind, return a UUID. The value is cached so + # that the same UUID is always returned for a given IteratorKind. 
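+    # (The stability comes from the ``lru_cache`` above: the first call
+    # for a given kind draws a fresh UUID, and every later call with an
+    # equal kind returns that same cached value for the process lifetime.)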
+ return uuid.uuid4().hex + + class IteratorBase: """ An Iterator is a wrapper around a pointer, and must define the following: - - a `state` property that returns a `ctypes.c_void_p` object, representing - a pointer to some data. - - an `advance` (static) method that receives the state pointer and performs + - an `advance` (static) method that receives the pointer and performs an action that advances the pointer by the offset `distance` (returns nothing). - - a `dereference` (static) method that dereferences the state pointer + - a `dereference` (static) method that dereferences the pointer and returns a value. Iterators are not meant to be used directly. They are constructed and passed @@ -38,29 +61,40 @@ class IteratorBase: The `advance` and `dereference` must be compilable to device code by numba. """ - def __init__(self, numba_type: types.Type, value_type: types.Type, abi_name: str): + iterator_kind_type: type # must be a subclass of IteratorKind + + def __init__( + self, + cvalue: ctypes.c_void_p, + numba_type: types.Type, + value_type: types.Type, + ): """ Parameters ---------- + cvalue + A ctypes type representing the object pointed to by the iterator. numba_type - A numba type that specifies how to interpret the state pointer. + A numba type representing the type of the input to the advance + and dereference functions. value_type The numba type of the value returned by the dereference operation. - abi_name - A unique identifier that will determine the abi_names for the - advance and dereference operations. """ + self.cvalue = cvalue self.numba_type = numba_type self.value_type = value_type - self.abi_name = abi_name + + @property + def kind(self): + return self.__class__.iterator_kind_type(self.value_type) # TODO: should we cache this? Current docs environment doesn't allow # using Python > 3.7. We could use a hand-rolled cached_property if # needed. 
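+    # Note: the ABI names below are derived from ``self.kind`` via
+    # ``_get_abi_suffix``, so two iterators of the same kind compile to
+    # identically named ``advance``/``dereference`` symbols.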
@property def ltoirs(self) -> Dict[str, bytes]: - advance_abi_name = self.abi_name + "_advance" - deref_abi_name = self.abi_name + "_dereference" + advance_abi_name = "advance_" + _get_abi_suffix(self.kind) + deref_abi_name = "dereference_" + _get_abi_suffix(self.kind) advance_ltoir, _ = cached_compile( self.__class__.advance, ( @@ -81,7 +115,7 @@ def ltoirs(self) -> Dict[str, bytes]: @property def state(self) -> ctypes.c_void_p: - raise NotImplementedError("Subclasses must override advance staticmethod") + return ctypes.cast(ctypes.pointer(self.cvalue), ctypes.c_void_p) @staticmethod def advance(state, distance): @@ -122,16 +156,20 @@ def impl(ptr, offset): return impl +class RawPointerType(IteratorKind): + pass + + class RawPointer(IteratorBase): - def __init__(self, ptr: int, ntype: types.Type): - value_type = ntype - self._cvalue = ctypes.c_void_p(ptr) + iterator_kind_type = RawPointerType + + def __init__(self, ptr: int, value_type: types.Type): + cvalue = ctypes.c_void_p(ptr) numba_type = types.CPointer(types.CPointer(value_type)) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( + cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -142,13 +180,9 @@ def advance(state, distance): def dereference(state): return state[0][0] - @property - def state(self) -> ctypes.c_void_p: - return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p) - -def pointer(container, ntype: types.Type) -> RawPointer: - return RawPointer(container.__cuda_array_interface__["data"][0], ntype) +def pointer(container, value_type: types.Type) -> RawPointer: + return RawPointer(container.__cuda_array_interface__["data"][0], value_type) @intrinsic @@ -172,16 +206,21 @@ def codegen(context, builder, sig, args): return base.dtype(base), codegen +class CacheModifiedPointerType(IteratorKind): + pass + + class CacheModifiedPointer(IteratorBase): + iterator_kind_type = CacheModifiedPointerType + def __init__(self, ptr: int, ntype: types.Type): - self._cvalue = ctypes.c_void_p(ptr) + cvalue = ctypes.c_void_p(ptr) value_type = ntype numba_type = types.CPointer(types.CPointer(value_type)) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( + cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -192,21 +231,22 @@ def advance(state, distance): def dereference(state): return load_cs(state[0]) - @property - def state(self) -> ctypes.c_void_p: - return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p) + +class ConstantIteratorKind(IteratorKind): + pass class ConstantIterator(IteratorBase): + iterator_kind_type = ConstantIteratorKind + def __init__(self, value: np.number): value_type = numba.from_dtype(value.dtype) - self._cvalue = to_ctypes(value_type)(value) + cvalue = to_ctypes(value_type)(value) numba_type = types.CPointer(value_type) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( + cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -217,21 +257,22 @@ def advance(state, distance): def dereference(state): return state[0] - @property - def state(self) -> ctypes.c_void_p: - return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p) + +class CountingIteratorKind(IteratorKind): + pass class CountingIterator(IteratorBase): + iterator_kind_type = CountingIteratorKind + def __init__(self, value: np.number): value_type = numba.from_dtype(value.dtype) - self._cvalue = 
to_ctypes(value_type)(value) + cvalue = to_ctypes(value_type)(value) numba_type = types.CPointer(value_type) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( + cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -242,9 +283,13 @@ def advance(state, distance): def dereference(state): return state[0] - @property - def state(self) -> ctypes.c_void_p: - return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p) + +class TransformIteratorKind(IteratorKind): + def __eq__(self, other): + return type(self) is type(other) and self.value_type == other.value_type + + def __hash__(self): + return hash(self.value_type) def make_transform_iterator(it, op: Callable): @@ -256,31 +301,32 @@ def make_transform_iterator(it, op: Callable): op = cuda.jit(op, device=True) class TransformIterator(IteratorBase): - def __init__(self, it: IteratorBase, op): + iterator_kind_type = TransformIteratorKind + + def __init__(self, it: IteratorBase, op: CUDADispatcher): self._it = it + self._op = CachableFunction(op.py_func) numba_type = it.numba_type - # TODO: the abi name below isn't unique enough when we have e.g., - # two identically named `op` functions with different - # signatures, bytecodes, and/or closure variables. - op_abi_name = f"{self.__class__.__name__}_{op.py_func.__name__}" - # TODO: it would be nice to not need to compile `op` to get # its return type, but there's nothing in the numba API # to do that (yet), _, op_retty = cached_compile( op, (self._it.value_type,), - abi_name=op_abi_name, + abi_name=f"{op.__name__}_{_get_abi_suffix(self.kind)}", output="ltoir", ) value_type = op_retty - abi_name = f"{self.__class__.__name__}_{it.abi_name}_{op_abi_name}" super().__init__( + cvalue=it.cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) + @property + def kind(self): + return self.__class__.iterator_kind_type((self._it.kind, self._op)) + @staticmethod def advance(state, distance): return it_advance(state, distance) @@ -289,8 +335,12 @@ def advance(state, distance): def dereference(state): return op(it_dereference(state)) - @property - def state(self) -> ctypes.c_void_p: - return it.state + def __hash__(self): + return hash((self._it, self._op)) + + def __eq__(self, other): + if not isinstance(other.kind, TransformIteratorKind): + return NotImplemented + return self._it == other._it and self._op == other._op return TransformIterator(it, op) diff --git a/python/cuda_parallel/cuda/parallel/experimental/typing.py b/python/cuda_parallel/cuda/parallel/experimental/typing.py new file mode 100644 index 00000000000..1c4e9c9975f --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/typing.py @@ -0,0 +1,12 @@ +from typing_extensions import ( + Protocol, +) # TODO: typing_extensions required for Python 3.7 docs env + + +class DeviceArrayLike(Protocol): + """ + Objects representing a device array, having a `.__cuda_array_interface__` + attribute. 
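+
+    CuPy arrays and Numba CUDA device arrays are typical examples of
+    conforming types.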
+ """ + + __cuda_array_interface__: dict diff --git a/python/cuda_parallel/setup.py b/python/cuda_parallel/setup.py index 40c998fafee..0db6592bb05 100644 --- a/python/cuda_parallel/setup.py +++ b/python/cuda_parallel/setup.py @@ -100,7 +100,8 @@ def build_extension(self, ext): ], packages=find_namespace_packages(include=["cuda.*"]), python_requires=">=3.9", - install_requires=["numba>=0.60.0", "cuda-python", "jinja2"], + # TODO: typing_extensions required for Python 3.7 docs env + install_requires=["numba>=0.60.0", "cuda-python", "jinja2", "typing_extensions"], extras_require={ "test": [ "pytest", diff --git a/python/cuda_parallel/tests/test_iterators.py b/python/cuda_parallel/tests/test_iterators.py new file mode 100644 index 00000000000..3b7910d404d --- /dev/null +++ b/python/cuda_parallel/tests/test_iterators.py @@ -0,0 +1,71 @@ +from cuda.parallel.experimental.iterators import ( + CacheModifiedInputIterator, + ConstantIterator, + CountingIterator, + TransformIterator, +) +import cupy as cp +import numpy as np + + +def test_constant_iterator_equality(): + it1 = ConstantIterator(np.int32(0)) + it2 = ConstantIterator(np.int32(0)) + it3 = ConstantIterator(np.int32(1)) + it4 = ConstantIterator(np.int64(9)) + + assert it1.kind == it2.kind == it3.kind + assert it1.kind != it4.kind + + +def test_counting_iterator_equality(): + it1 = CountingIterator(np.int32(0)) + it2 = CountingIterator(np.int32(0)) + it3 = CountingIterator(np.int32(1)) + it4 = CountingIterator(np.int64(9)) + + assert it1.kind == it2.kind == it3.kind + assert it1.kind != it4.kind + + +def test_cache_modified_input_iterator_equality(): + ary1 = cp.asarray([0, 1, 2], dtype="int32") + ary2 = cp.asarray([3, 4, 5], dtype="int32") + ary3 = cp.asarray([0, 1, 2], dtype="int64") + + it1 = CacheModifiedInputIterator(ary1, "stream") + it2 = CacheModifiedInputIterator(ary1, "stream") + it3 = CacheModifiedInputIterator(ary2, "stream") + it4 = CacheModifiedInputIterator(ary3, "stream") + + assert it1.kind == it2.kind == it3.kind + assert it1.kind != it4.kind + + +def test_equality_transform_iterator(): + def op1(x): + return x + + def op2(x): + return 2 * x + + def op3(x): + return x + + it = CountingIterator(np.int32(0)) + it1 = TransformIterator(it, op1) + it2 = TransformIterator(it, op1) + it3 = TransformIterator(it, op3) + + assert it1.kind == it2.kind == it3.kind + + ary1 = cp.asarray([0, 1, 2]) + ary2 = cp.asarray([3, 4, 5]) + it4 = TransformIterator(ary1, op1) + it5 = TransformIterator(ary1, op1) + it6 = TransformIterator(ary1, op2) + it7 = TransformIterator(ary1, op3) + it8 = TransformIterator(ary2, op1) + + assert it4.kind == it5.kind == it7.kind == it8.kind + assert it4.kind != it6.kind diff --git a/python/cuda_parallel/tests/test_reduce.py b/python/cuda_parallel/tests/test_reduce.py index 0f454e3603b..ce5656f635d 100644 --- a/python/cuda_parallel/tests/test_reduce.py +++ b/python/cuda_parallel/tests/test_reduce.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import cupy as cp -import numpy +import numpy as np import pytest import random import numba.cuda @@ -14,31 +14,29 @@ def random_int(shape, dtype): - return numpy.random.randint(0, 5, size=shape).astype(dtype) + return np.random.randint(0, 5, size=shape).astype(dtype) def type_to_problem_sizes(dtype): - if dtype in [numpy.uint8, numpy.int8]: + if dtype in [np.uint8, np.int8]: return [2, 4, 5, 6] - elif dtype in [numpy.uint16, numpy.int16]: + elif dtype in [np.uint16, np.int16]: return [4, 8, 12, 14] - elif dtype in [numpy.uint32, numpy.int32]: + 
elif dtype in [np.uint32, np.int32]: return [16, 20, 24, 28] - elif dtype in [numpy.uint64, numpy.int64]: + elif dtype in [np.uint64, np.int64]: return [16, 20, 24, 28] else: raise ValueError("Unsupported dtype") -@pytest.mark.parametrize( - "dtype", [numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64] -) +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64]) def test_device_reduce(dtype): def op(a, b): return a + b init_value = 42 - h_init = numpy.array([init_value], dtype=dtype) + h_init = np.array([init_value], dtype=dtype) d_output = numba.cuda.device_array(1, dtype=dtype) reduce_into = algorithms.reduce_into(d_output, d_output, op, h_init) @@ -47,7 +45,7 @@ def op(a, b): h_input = random_int(num_items, dtype) d_input = numba.cuda.to_device(h_input) temp_storage_size = reduce_into(None, d_input, d_output, None, h_init) - d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=numpy.uint8) + d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=np.uint8) reduce_into(d_temp_storage, d_input, d_output, None, h_init) h_output = d_output.copy_to_host() assert h_output[0] == sum(h_input) + init_value @@ -57,19 +55,19 @@ def test_complex_device_reduce(): def op(a, b): return a + b - h_init = numpy.array([40.0 + 2.0j], dtype=complex) + h_init = np.array([40.0 + 2.0j], dtype=complex) d_output = numba.cuda.device_array(1, dtype=complex) reduce_into = algorithms.reduce_into(d_output, d_output, op, h_init) for num_items in [42, 420000]: - h_input = numpy.random.random(num_items) + 1j * numpy.random.random(num_items) + h_input = np.random.random(num_items) + 1j * np.random.random(num_items) d_input = numba.cuda.to_device(h_input) temp_storage_bytes = reduce_into(None, d_input, d_output, None, h_init) - d_temp_storage = numba.cuda.device_array(temp_storage_bytes, numpy.uint8) + d_temp_storage = numba.cuda.device_array(temp_storage_bytes, np.uint8) reduce_into(d_temp_storage, d_input, d_output, None, h_init) result = d_output.copy_to_host()[0] - expected = numpy.sum(h_input, initial=h_init[0]) + expected = np.sum(h_input, initial=h_init[0]) assert result == pytest.approx(expected) @@ -77,9 +75,9 @@ def test_device_reduce_dtype_mismatch(): def min_op(a, b): return a if a < b else b - dtypes = [numpy.int32, numpy.int64] - h_inits = [numpy.array([], dt) for dt in dtypes] - h_inputs = [numpy.array([], dt) for dt in dtypes] + dtypes = [np.int32, np.int64] + h_inits = [np.array([], dt) for dt in dtypes] + h_inputs = [np.array([], dt) for dt in dtypes] d_outputs = [numba.cuda.device_array(1, dt) for dt in dtypes] d_inputs = [numba.cuda.to_device(h_inp) for h_inp in h_inputs] @@ -109,14 +107,14 @@ def add_op(a, b): expected_result = add_op(expected_result, v) if use_numpy_array: - h_input = numpy.array(l_varr, dtype_inp) + h_input = np.array(l_varr, dtype_inp) d_input = numba.cuda.to_device(h_input) else: d_input = i_input d_output = numba.cuda.device_array(1, dtype_out) # to store device sum - h_init = numpy.array([start_sum_with], dtype_out) + h_init = np.array([start_sum_with], dtype_out) reduce_into = algorithms.reduce_into( d_in=d_input, d_out=d_output, op=add_op, h_init=h_init @@ -125,7 +123,7 @@ def add_op(a, b): temp_storage_size = reduce_into( None, d_in=d_input, d_out=d_output, num_items=len(l_varr), h_init=h_init ) - d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=numpy.uint8) + d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=np.uint8) reduce_into(d_temp_storage, d_input, d_output, len(l_varr), h_init) @@ -168,9 
+166,9 @@ def test_device_sum_cache_modified_input_it( ): rng = random.Random(0) l_varr = [rng.randrange(100) for _ in range(num_items)] - dtype_inp = numpy.dtype(supported_value_type) + dtype_inp = np.dtype(supported_value_type) dtype_out = dtype_inp - input_devarr = numba.cuda.to_device(numpy.array(l_varr, dtype=dtype_inp)) + input_devarr = numba.cuda.to_device(np.array(l_varr, dtype=dtype_inp)) i_input = iterators.CacheModifiedInputIterator(input_devarr, modifier="stream") _test_device_sum_with_iterator( l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array @@ -181,7 +179,7 @@ def test_device_sum_constant_it( use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 ): l_varr = [42 for distance in range(num_items)] - dtype_inp = numpy.dtype(supported_value_type) + dtype_inp = np.dtype(supported_value_type) dtype_out = dtype_inp i_input = iterators.ConstantIterator(dtype_inp.type(42)) _test_device_sum_with_iterator( @@ -193,7 +191,7 @@ def test_device_sum_counting_it( use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 ): l_varr = [start_sum_with + distance for distance in range(num_items)] - dtype_inp = numpy.dtype(supported_value_type) + dtype_inp = np.dtype(supported_value_type) dtype_out = dtype_inp i_input = iterators.CountingIterator(dtype_inp.type(start_sum_with)) _test_device_sum_with_iterator( @@ -217,8 +215,8 @@ def test_device_sum_map_mul2_count_it( ): l_varr = [2 * (start_sum_with + distance) for distance in range(num_items)] vtn_out, vtn_inp = value_type_name_pair - dtype_inp = numpy.dtype(vtn_inp) - dtype_out = numpy.dtype(vtn_out) + dtype_inp = np.dtype(vtn_inp) + dtype_out = np.dtype(vtn_out) i_input = iterators.TransformIterator( iterators.CountingIterator(dtype_inp.type(start_sum_with)), mul2 ) @@ -248,8 +246,8 @@ def test_device_sum_map_mul_map_mul_count_it( fac_out * (fac_mid * (start_sum_with + distance)) for distance in range(num_items) ] - dtype_inp = numpy.dtype(vtn_inp) - dtype_out = numpy.dtype(vtn_out) + dtype_inp = np.dtype(vtn_inp) + dtype_out = np.dtype(vtn_out) mul_funcs = {2: mul2, 3: mul3} i_input = iterators.TransformIterator( iterators.TransformIterator( @@ -275,8 +273,8 @@ def test_device_sum_map_mul2_cp_array_it( use_numpy_array, value_type_name_pair, num_items=3, start_sum_with=10 ): vtn_out, vtn_inp = value_type_name_pair - dtype_inp = numpy.dtype(vtn_inp) - dtype_out = numpy.dtype(vtn_out) + dtype_inp = np.dtype(vtn_inp) + dtype_out = np.dtype(vtn_out) rng = random.Random(0) l_d_in = [rng.randrange(100) for _ in range(num_items)] a_d_in = cp.array(l_d_in, dtype_inp) @@ -285,3 +283,221 @@ def test_device_sum_map_mul2_cp_array_it( _test_device_sum_with_iterator( l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array ) + + +def test_reducer_caching(): + def sum_op(x, y): + return x + y + + # inputs are device arrays + reducer_1 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are device arrays of different dtype: + reducer_1 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + cp.zeros(3, dtype="int32"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is not 
reducer_2 + + # outputs are of different dtype: + reducer_1 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int32"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is not reducer_2 + + # inputs are of same dtype but different size + # (should still use cached reducer): + reducer_1 = algorithms.reduce_into( + cp.zeros(3, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + cp.zeros(5, dtype="int64"), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are counting iterators of the + # same value type: + reducer_1 = algorithms.reduce_into( + iterators.CountingIterator(np.int32(0)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.CountingIterator(np.int32(0)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are counting iterators of different value type: + reducer_1 = algorithms.reduce_into( + iterators.CountingIterator(np.int32(0)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.CountingIterator(np.int64(0)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is not reducer_2 + + def op1(x): + return x + + def op2(x): + return 2 * x + + def op3(x): + return x + + # inputs are TransformIterators + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are TransformIterators with different + # op: + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op2), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is not reducer_2 + + # inputs are TransformIterators with same op + # but different name: + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op3), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + + # inputs are CountingIterators of same kind + # but different state: + reducer_1 = algorithms.reduce_into( + iterators.CountingIterator(np.int32(0)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.CountingIterator(np.int32(1)), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + + assert reducer_1 is reducer_2 + + # inputs are TransformIterators of same kind + # but 
different state: + ary1 = cp.asarray([0, 1, 2], dtype="int64") + ary2 = cp.asarray([0, 1], dtype="int64") + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(ary1, op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(ary2, op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are TransformIterators of same kind + # but different state: + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(1)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is reducer_2 + + # inputs are TransformIterators with different kind: + reducer_1 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int32(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + reducer_2 = algorithms.reduce_into( + iterators.TransformIterator(iterators.CountingIterator(np.int64(0)), op1), + cp.zeros(1, dtype="int64"), + sum_op, + np.zeros([0], dtype="int64"), + ) + assert reducer_1 is not reducer_2 From f5dddc4ba2b14cf087846409cd06881ac46b98a6 Mon Sep 17 00:00:00 2001 From: Eric Niebler Date: Thu, 2 Jan 2025 02:31:29 -0800 Subject: [PATCH 3/4] Just enough ranges for c++14 `span` (#3211) --- .../cuda/std/__functional/ranges_operations.h | 2 +- .../include/cuda/std/__iterator/concepts.h | 2 +- .../std/__iterator/incrementable_traits.h | 4 +- .../std/__iterator/indirectly_comparable.h | 4 +- .../include/cuda/std/__iterator/iter_move.h | 2 +- .../include/cuda/std/__iterator/iter_swap.h | 4 +- .../cuda/std/__iterator/iterator_traits.h | 11 +- .../include/cuda/std/__iterator/mergeable.h | 2 +- .../include/cuda/std/__iterator/permutable.h | 4 +- .../include/cuda/std/__iterator/projected.h | 4 +- .../cuda/std/__iterator/readable_traits.h | 2 +- .../include/cuda/std/__iterator/sortable.h | 4 +- libcudacxx/include/cuda/std/__ranges/access.h | 20 +- .../include/cuda/std/__ranges/concepts.h | 4 +- .../include/cuda/std/__ranges/dangling.h | 4 +- libcudacxx/include/cuda/std/__ranges/data.h | 4 +- .../cuda/std/__ranges/enable_borrowed_range.h | 2 +- .../include/cuda/std/__ranges/enable_view.h | 4 +- libcudacxx/include/cuda/std/__ranges/size.h | 15 +- libcudacxx/include/cuda/std/__ranges/views.h | 6 +- .../cuda/std/__type_traits/common_reference.h | 4 + .../cuda/std/detail/libcxx/include/span | 160 +++------ .../enable_borrowed_range.compile.pass.cpp | 12 +- ...range_concept_conformance.compile.pass.cpp | 34 +- .../views.span/span.cons/deduct.pass.cpp | 2 - .../span.cons/initializer_list.pass.cpp | 4 +- .../cxx20_iterator_traits.compile.pass.cpp | 318 +++++++++--------- .../iter_reference_t.compile.pass.cpp | 10 +- .../contiguous_iterator_tag.pass.cpp | 7 +- .../indirectly_copyable.compile.pass.cpp | 64 ++-- ...ctly_copyable.subsumption.compile.pass.cpp | 2 +- ...irectly_copyable_storable.compile.pass.cpp | 127 ++++--- ...able_storable.subsumption.compile.pass.cpp | 2 +- .../indirectly_movable.compile.pass.cpp | 56 +-- ...ectly_movable.subsumption.compile.pass.cpp | 2 +- ...directly_movable_storable.compile.pass.cpp | 63 ++-- ...able_storable.subsumption.compile.pass.cpp | 2 +- 
.../indirectly_swappable.compile.pass.cpp | 16 +- ...tly_swappable.subsumption.compile.pass.cpp | 2 +- .../mergeable.compile.pass.cpp | 95 +++--- .../mergeable.subsumption.compile.pass.cpp | 2 +- .../permutable.compile.pass.cpp | 26 +- .../permutable.subsumption.compile.pass.cpp | 2 +- .../sortable.compile.pass.cpp | 42 +-- .../sortable.subsumption.compile.pass.cpp | 2 +- ...indirect_binary_predicate.compile.pass.cpp | 24 +- ...rect_equivalence_relation.compile.pass.cpp | 24 +- .../indirect_result_t.compile.pass.cpp | 18 +- ...ndirect_strict_weak_order.compile.pass.cpp | 26 +- .../indirect_unary_predicate.compile.pass.cpp | 18 +- .../indirectly_comparable.compile.pass.cpp | 18 +- ...ly_comparable.subsumption.compile.pass.cpp | 4 +- ...y_regular_unary_invocable.compile.pass.cpp | 40 +-- ...ndirectly_unary_invocable.compile.pass.cpp | 40 +-- .../projected/projected.compile.pass.cpp | 50 +-- .../incrementable_traits.compile.pass.cpp | 146 ++++---- .../iter_difference_t.compile.pass.cpp | 38 +-- ...ndirectly_readable_traits.compile.pass.cpp | 152 +++++---- .../readable.traits/iter_value_t.pass.cpp | 40 +-- .../bidirectional_iterator.compile.pass.cpp | 37 +- .../subsumption.compile.pass.cpp | 2 +- .../forward_iterator.compile.pass.cpp | 34 +- .../subsumption.compile.pass.cpp | 2 +- .../incrementable.compile.pass.cpp | 46 +-- .../subsumption.compile.pass.cpp | 2 +- .../input_iterator.compile.pass.cpp | 27 +- .../subsumption.compile.pass.cpp | 2 +- .../input_or_output_iterator.compile.pass.cpp | 69 ++-- .../subsumption.compile.pass.cpp | 2 +- .../output_iterator.compile.pass.cpp | 54 +-- .../contiguous_iterator.compile.pass.cpp | 64 ++-- .../random_access_iterator.compile.pass.cpp | 58 ++-- .../indirectly_readable.compile.pass.cpp | 117 +++---- .../iter_common_reference_t.compile.pass.cpp | 21 +- .../sentinel_for.compile.pass.cpp | 17 +- .../sentinel_for.subsumption.compile.pass.cpp | 2 +- .../sized_sentinel_for.compile.pass.cpp | 37 +- .../weakly_incrementable.compile.pass.cpp | 78 ++--- .../indirectly_writable.compile.pass.cpp | 54 +-- .../iter_move.nodiscard.verify.cpp | 2 +- .../iterator.cust.move/iter_move.pass.cpp | 20 +- .../iter_rvalue_reference_t.compile.pass.cpp | 8 +- .../iterator.cust.swap/iter_swap.pass.cpp | 46 +-- .../std/ranges/range.access/begin.pass.cpp | 198 ++++++----- .../std/ranges/range.access/data.pass.cpp | 227 +++++++------ .../std/ranges/range.access/end.pass.cpp | 179 +++++----- .../std/ranges/range.access/size.pass.cpp | 85 ++--- .../borrowed_range.compile.pass.cpp | 42 +-- ...orrowed_range.subsumption.compile.pass.cpp | 2 +- .../enable_borrowed_range.compile.pass.cpp | 16 +- .../helper_aliases.compile.pass.cpp | 17 +- .../range.range/iterator_t.compile.pass.cpp | 21 +- .../range.range/range.compile.pass.cpp | 14 +- .../range.range/range_size_t.compile.pass.cpp | 28 +- .../range.range/sentinel_t.compile.pass.cpp | 17 +- .../bidirectional_range.compile.pass.cpp | 52 +-- .../common_range.compile.pass.cpp | 74 ++-- .../contiguous_range.compile.pass.cpp | 68 ++-- .../forward_range.compile.pass.cpp | 52 +-- .../input_range.compile.pass.cpp | 60 ++-- .../output_range.compile.pass.cpp | 44 +-- .../random_access_range.compile.pass.cpp | 50 +-- .../viewable_range.compile.pass.cpp | 208 ++++++------ .../range.sized/sized_range.compile.pass.cpp | 44 +-- .../range.view/enable_view.compile.pass.cpp | 145 ++++---- .../range.view/view.compile.pass.cpp | 52 +-- .../range.view/view_base.compile.pass.cpp | 6 +- libcudacxx/test/support/indirectly_readable.h | 12 +- 
libcudacxx/test/support/test_iterators.h | 61 ++-- libcudacxx/test/support/test_range.h | 14 +- 110 files changed, 2262 insertions(+), 2125 deletions(-) diff --git a/libcudacxx/include/cuda/std/__functional/ranges_operations.h b/libcudacxx/include/cuda/std/__functional/ranges_operations.h index 059eda975f4..b15d3960202 100644 --- a/libcudacxx/include/cuda/std/__functional/ranges_operations.h +++ b/libcudacxx/include/cuda/std/__functional/ranges_operations.h @@ -24,7 +24,7 @@ #include #include -#if _CCCL_STD_VER >= 2017 +#if _CCCL_STD_VER >= 2014 _LIBCUDACXX_BEGIN_NAMESPACE_RANGES _LIBCUDACXX_BEGIN_NAMESPACE_RANGES_ABI diff --git a/libcudacxx/include/cuda/std/__iterator/concepts.h b/libcudacxx/include/cuda/std/__iterator/concepts.h index ef36ad11f9d..bd24b7e3803 100644 --- a/libcudacxx/include/cuda/std/__iterator/concepts.h +++ b/libcudacxx/include/cuda/std/__iterator/concepts.h @@ -254,7 +254,7 @@ concept indirectly_copyable_storable = // Note: indirectly_swappable is located in iter_swap.h to prevent a dependency cycle // (both iter_swap and indirectly_swappable require indirectly_readable). -#elif _CCCL_STD_VER > 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ +#elif _CCCL_STD_VER >= 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ // [iterator.concept.readable] template diff --git a/libcudacxx/include/cuda/std/__iterator/incrementable_traits.h b/libcudacxx/include/cuda/std/__iterator/incrementable_traits.h index 4555b4ae412..d188a4ae66c 100644 --- a/libcudacxx/include/cuda/std/__iterator/incrementable_traits.h +++ b/libcudacxx/include/cuda/std/__iterator/incrementable_traits.h @@ -88,7 +88,7 @@ using iter_difference_t = incrementable_traits>, iterator_traits>>::difference_type; -#elif _CCCL_STD_VER > 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ +#elif _CCCL_STD_VER >= 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ // [incrementable.traits] template @@ -150,7 +150,7 @@ using iter_difference_t = incrementable_traits>, iterator_traits>>::difference_type; -#endif // _CCCL_STD_VER > 2014 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_STD diff --git a/libcudacxx/include/cuda/std/__iterator/indirectly_comparable.h b/libcudacxx/include/cuda/std/__iterator/indirectly_comparable.h index bc5c2b615e5..7646e3dfb94 100644 --- a/libcudacxx/include/cuda/std/__iterator/indirectly_comparable.h +++ b/libcudacxx/include/cuda/std/__iterator/indirectly_comparable.h @@ -33,7 +33,7 @@ template , projected<_Iter2, _Proj2>>; -#elif _CCCL_STD_VER > 2014 +#elif _CCCL_STD_VER >= 2014 // clang-format off @@ -50,7 +50,7 @@ _CCCL_CONCEPT indirectly_comparable = // clang-format on -#endif // _CCCL_STD_VER > 2014 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_STD diff --git a/libcudacxx/include/cuda/std/__iterator/iter_move.h b/libcudacxx/include/cuda/std/__iterator/iter_move.h index 54ce7692c1e..22a13ef33ae 100644 --- a/libcudacxx/include/cuda/std/__iterator/iter_move.h +++ b/libcudacxx/include/cuda/std/__iterator/iter_move.h @@ -33,7 +33,7 @@ _CCCL_DIAG_PUSH _CCCL_DIAG_SUPPRESS_CLANG("-Wvoid-ptr-dereference") -#if _CCCL_STD_VER > 2014 +#if _CCCL_STD_VER >= 2014 // [iterator.cust.move] diff --git a/libcudacxx/include/cuda/std/__iterator/iter_swap.h b/libcudacxx/include/cuda/std/__iterator/iter_swap.h index bafeed69742..d58e2b3740e 100644 --- a/libcudacxx/include/cuda/std/__iterator/iter_swap.h +++ b/libcudacxx/include/cuda/std/__iterator/iter_swap.h @@ -30,7 +30,7 @@ #include #include -#if _CCCL_STD_VER > 2014 +#if _CCCL_STD_VER >= 2014 // [iter.cust.swap] @@ -158,6 +158,6 @@ _CCCL_INLINE_VAR constexpr bool __noexcept_swappable<_I1, _I2, 
enable_if_t 2014 +#endif // _CCCL_STD_VER >= 2014 #endif // _LIBCUDACXX___ITERATOR_ITER_SWAP_H diff --git a/libcudacxx/include/cuda/std/__iterator/iterator_traits.h b/libcudacxx/include/cuda/std/__iterator/iterator_traits.h index 2168ea2fd5c..095880f7cce 100644 --- a/libcudacxx/include/cuda/std/__iterator/iterator_traits.h +++ b/libcudacxx/include/cuda/std/__iterator/iterator_traits.h @@ -34,7 +34,6 @@ #include #include #include -#include #include #include #include @@ -93,7 +92,7 @@ using iter_reference_t = decltype(*declval<_Tp&>()); template struct _CCCL_TYPE_VISIBILITY_DEFAULT iterator_traits; -#elif _CCCL_STD_VER > 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ +#elif _CCCL_STD_VER >= 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ template using __with_reference = _Tp&; @@ -119,10 +118,10 @@ using iter_reference_t = enable_if_t<__dereferenceable<_Tp>, decltype(*declval<_ template struct _CCCL_TYPE_VISIBILITY_DEFAULT iterator_traits; -#else // ^^^ _CCCL_STD_VER > 2014 ^^^ / vvv _CCCL_STD_VER <= 2014 vvv +#else // ^^^ _CCCL_STD_VER >= 2014 ^^^ / vvv _CCCL_STD_VER < 2014 vvv template struct _CCCL_TYPE_VISIBILITY_DEFAULT iterator_traits; -#endif // _CCCL_STD_VER <= 2014 +#endif // _CCCL_STD_VER < 2014 #if _CCCL_COMPILER(NVRTC) @@ -530,7 +529,7 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT iterator_traits : __iterator_traits<_Ip> using __primary_template = iterator_traits; }; -#elif _CCCL_STD_VER > 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ / vvv _CCCL_STD_VER > 2014 vvv +#elif _CCCL_STD_VER >= 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ / vvv _CCCL_STD_VER > 2014 vvv // The `cpp17-*-iterator` exposition-only concepts have very similar names to the `Cpp17*Iterator` named requirements // from `[iterator.cpp17]`. To avoid confusion between the two, the exposition-only concepts have been banished to @@ -860,7 +859,7 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT iterator_traits<_Tp*> typedef _Tp* pointer; typedef typename add_lvalue_reference<_Tp>::type reference; typedef random_access_iterator_tag iterator_category; -#if _CCCL_STD_VER > 2014 +#if _CCCL_STD_VER >= 2014 typedef contiguous_iterator_tag iterator_concept; #endif }; diff --git a/libcudacxx/include/cuda/std/__iterator/mergeable.h b/libcudacxx/include/cuda/std/__iterator/mergeable.h index 7c788375bc6..62a9a90a662 100644 --- a/libcudacxx/include/cuda/std/__iterator/mergeable.h +++ b/libcudacxx/include/cuda/std/__iterator/mergeable.h @@ -41,7 +41,7 @@ concept mergeable = && indirectly_copyable<_Input1, _Output> && indirectly_copyable<_Input2, _Output> && indirect_strict_weak_order<_Comp, projected<_Input1, _Proj1>, projected<_Input2, _Proj2>>; -#elif _CCCL_STD_VER > 2014 +#elif _CCCL_STD_VER >= 2014 template _CCCL_CONCEPT_FRAGMENT( diff --git a/libcudacxx/include/cuda/std/__iterator/permutable.h b/libcudacxx/include/cuda/std/__iterator/permutable.h index 599ede609ef..36968925759 100644 --- a/libcudacxx/include/cuda/std/__iterator/permutable.h +++ b/libcudacxx/include/cuda/std/__iterator/permutable.h @@ -32,7 +32,7 @@ template concept permutable = forward_iterator<_Iterator> && indirectly_movable_storable<_Iterator, _Iterator> && indirectly_swappable<_Iterator, _Iterator>; -#elif _CCCL_STD_VER > 2014 +#elif _CCCL_STD_VER >= 2014 template _CCCL_CONCEPT_FRAGMENT(__permutable_, @@ -43,7 +43,7 @@ _CCCL_CONCEPT_FRAGMENT(__permutable_, template _CCCL_CONCEPT permutable = _CCCL_FRAGMENT(__permutable_, _Iterator); -#endif // _CCCL_STD_VER > 2014 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_STD diff --git a/libcudacxx/include/cuda/std/__iterator/projected.h 
b/libcudacxx/include/cuda/std/__iterator/projected.h index e8639b48d3b..d65eb462483 100644 --- a/libcudacxx/include/cuda/std/__iterator/projected.h +++ b/libcudacxx/include/cuda/std/__iterator/projected.h @@ -27,7 +27,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD -#if _CCCL_STD_VER > 2014 +#if _CCCL_STD_VER >= 2014 template struct __projected_impl @@ -54,7 +54,7 @@ _CCCL_TEMPLATE(class _It, class _Proj) _CCCL_REQUIRES(indirectly_readable<_It> _CCCL_AND indirectly_regular_unary_invocable<_Proj, _It>) using projected = typename __projected_impl<_It, _Proj>::__type; -#endif // _CCCL_STD_VER > 2014 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_STD diff --git a/libcudacxx/include/cuda/std/__iterator/readable_traits.h b/libcudacxx/include/cuda/std/__iterator/readable_traits.h index b73086dd968..8e5a1266d55 100644 --- a/libcudacxx/include/cuda/std/__iterator/readable_traits.h +++ b/libcudacxx/include/cuda/std/__iterator/readable_traits.h @@ -106,7 +106,7 @@ using iter_value_t = indirectly_readable_traits>, iterator_traits>>::value_type; -#elif _CCCL_STD_VER > 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ +#elif _CCCL_STD_VER >= 2014 // ^^^ !_CCCL_NO_CONCEPTS ^^^ // [readable.traits] template diff --git a/libcudacxx/include/cuda/std/__iterator/sortable.h b/libcudacxx/include/cuda/std/__iterator/sortable.h index 51cd2b00398..9656add726e 100644 --- a/libcudacxx/include/cuda/std/__iterator/sortable.h +++ b/libcudacxx/include/cuda/std/__iterator/sortable.h @@ -34,7 +34,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD template concept sortable = permutable<_Iter> && indirect_strict_weak_order<_Comp, projected<_Iter, _Proj>>; -#elif _CCCL_STD_VER > 2014 +#elif _CCCL_STD_VER >= 2014 template _CCCL_CONCEPT_FRAGMENT( @@ -44,7 +44,7 @@ _CCCL_CONCEPT_FRAGMENT( template _CCCL_CONCEPT sortable = _CCCL_FRAGMENT(__sortable_, _Iter, _Comp, _Proj); -#endif // _CCCL_STD_VER > 2014 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_STD diff --git a/libcudacxx/include/cuda/std/__ranges/access.h b/libcudacxx/include/cuda/std/__ranges/access.h index 3c5ef7da52b..c6ba238ea41 100644 --- a/libcudacxx/include/cuda/std/__ranges/access.h +++ b/libcudacxx/include/cuda/std/__ranges/access.h @@ -33,7 +33,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER > 2014 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 template _CCCL_CONCEPT __can_borrow = is_lvalue_reference_v<_Tp> || enable_borrowed_range>; @@ -120,6 +120,14 @@ struct __fn _CCCL_TEMPLATE(class _Tp) _CCCL_REQUIRES((!__member_begin<_Tp>) _CCCL_AND(!__unqualified_begin<_Tp>)) void operator()(_Tp&&) const = delete; + +# if _CCCL_COMPILER(MSVC, <, 19, 23) + template + void operator()(_Tp (&&)[]) const = delete; + + template + void operator()(_Tp (&&)[_Np]) const = delete; +# endif // _CCCL_COMPILER(MSVC, <, 19, 23) }; _LIBCUDACXX_END_NAMESPACE_CPO @@ -209,6 +217,14 @@ struct __fn _CCCL_TEMPLATE(class _Tp) _CCCL_REQUIRES((!__member_end<_Tp>) _CCCL_AND(!__unqualified_end<_Tp>)) void operator()(_Tp&&) const = delete; + +# if _CCCL_COMPILER(MSVC, <, 19, 23) + template + void operator()(_Tp (&&)[]) const = delete; + + template + void operator()(_Tp (&&)[_Np]) const = delete; +# endif // _CCCL_COMPILER(MSVC, <, 19, 23) }; _LIBCUDACXX_END_NAMESPACE_CPO @@ -279,7 +295,7 @@ inline namespace __cpo { _CCCL_GLOBAL_CONSTANT auto cend = __cend::__fn{}; } // namespace __cpo -#endif // _CCCL_STD_VER > 2014 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/concepts.h 
b/libcudacxx/include/cuda/std/__ranges/concepts.h index 4183f423ea6..3ba9a23abb9 100644 --- a/libcudacxx/include/cuda/std/__ranges/concepts.h +++ b/libcudacxx/include/cuda/std/__ranges/concepts.h @@ -44,7 +44,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 # if !defined(_CCCL_NO_CONCEPTS) @@ -302,7 +302,7 @@ template _CCCL_CONCEPT __container_compatible_range = _CCCL_FRAGMENT(__container_compatible_range_, _Range, _Tp); # endif // _CCCL_NO_CONCEPTS -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/dangling.h b/libcudacxx/include/cuda/std/__ranges/dangling.h index b97e5e5555a..7e99382cb3a 100644 --- a/libcudacxx/include/cuda/std/__ranges/dangling.h +++ b/libcudacxx/include/cuda/std/__ranges/dangling.h @@ -27,7 +27,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 struct dangling { @@ -47,7 +47,7 @@ using borrowed_iterator_t = enable_if_t, _If, ite // borrowed_subrange_t defined in <__ranges/subrange.h> -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/data.h b/libcudacxx/include/cuda/std/__ranges/data.h index 0f756d52a9f..a9c5db6f085 100644 --- a/libcudacxx/include/cuda/std/__ranges/data.h +++ b/libcudacxx/include/cuda/std/__ranges/data.h @@ -34,7 +34,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 // [range.prim.data] @@ -128,7 +128,7 @@ inline namespace __cpo _CCCL_GLOBAL_CONSTANT auto cdata = __cdata::__fn{}; } // namespace __cpo -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/enable_borrowed_range.h b/libcudacxx/include/cuda/std/__ranges/enable_borrowed_range.h index f0c9a58e679..79d6c23b8da 100644 --- a/libcudacxx/include/cuda/std/__ranges/enable_borrowed_range.h +++ b/libcudacxx/include/cuda/std/__ranges/enable_borrowed_range.h @@ -25,7 +25,7 @@ # pragma system_header #endif // no system header -#if _CCCL_STD_VER > 2014 +#if _CCCL_STD_VER >= 2014 _LIBCUDACXX_BEGIN_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/enable_view.h b/libcudacxx/include/cuda/std/__ranges/enable_view.h index 72e390c0499..315319387d2 100644 --- a/libcudacxx/include/cuda/std/__ranges/enable_view.h +++ b/libcudacxx/include/cuda/std/__ranges/enable_view.h @@ -30,7 +30,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER >= 2017 +#if _CCCL_STD_VER >= 2014 struct view_base {}; @@ -74,7 +74,7 @@ _CCCL_INLINE_VAR constexpr bool true; # endif // _CCCL_NO_CONCEPTS -#endif // _CCCL_STD_VER >= 2017 +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/size.h b/libcudacxx/include/cuda/std/__ranges/size.h index 0b432ae6e87..92d41a62052 100644 --- a/libcudacxx/include/cuda/std/__ranges/size.h +++ b/libcudacxx/include/cuda/std/__ranges/size.h @@ -36,7 +36,7 @@ _LIBCUDACXX_BEGIN_NAMESPACE_RANGES -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 template _CCCL_INLINE_VAR constexpr bool disable_sized_range = false; @@ -182,15 +182,8 @@ struct __fn noexcept(noexcept(_CUDA_VRANGES::size(__t))) { using _Signed = 
make_signed_t; - if constexpr (sizeof(ptrdiff_t) > sizeof(_Signed)) - { - return static_cast(_CUDA_VRANGES::size(__t)); - } - else - { - return static_cast<_Signed>(_CUDA_VRANGES::size(__t)); - } - _CCCL_UNREACHABLE(); + using _Result = conditional_t<(sizeof(ptrdiff_t) > sizeof(_Signed)), ptrdiff_t, _Signed>; + return static_cast<_Result>(_CUDA_VRANGES::size(__t)); } }; _LIBCUDACXX_END_NAMESPACE_CPO @@ -200,7 +193,7 @@ inline namespace __cpo _CCCL_GLOBAL_CONSTANT auto ssize = __ssize::__fn{}; } // namespace __cpo -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 _LIBCUDACXX_END_NAMESPACE_RANGES diff --git a/libcudacxx/include/cuda/std/__ranges/views.h b/libcudacxx/include/cuda/std/__ranges/views.h index 3954877f117..7dae143e8b2 100644 --- a/libcudacxx/include/cuda/std/__ranges/views.h +++ b/libcudacxx/include/cuda/std/__ranges/views.h @@ -21,7 +21,7 @@ # pragma system_header #endif // no system header -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 _LIBCUDACXX_BEGIN_NAMESPACE_VIEWS @@ -29,10 +29,10 @@ _LIBCUDACXX_END_NAMESPACE_VIEWS _LIBCUDACXX_BEGIN_NAMESPACE_STD -namespace views = ranges::views; +namespace views = ranges::views; // NOLINT: misc-unused-alias-decls _LIBCUDACXX_END_NAMESPACE_STD -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 #endif // _LIBCUDACXX___RANGES_VIEWS diff --git a/libcudacxx/include/cuda/std/__type_traits/common_reference.h b/libcudacxx/include/cuda/std/__type_traits/common_reference.h index 6f62a1033ef..7db241807eb 100644 --- a/libcudacxx/include/cuda/std/__type_traits/common_reference.h +++ b/libcudacxx/include/cuda/std/__type_traits/common_reference.h @@ -37,6 +37,8 @@ #include #include +_CCCL_NV_DIAG_SUPPRESS(1384) // warning: pointer converted to bool + _LIBCUDACXX_BEGIN_NAMESPACE_STD // common_reference @@ -253,4 +255,6 @@ struct common_reference _LIBCUDACXX_END_NAMESPACE_STD +_CCCL_NV_DIAG_DEFAULT(1384) + #endif // _LIBCUDACXX___TYPE_TRAITS_COMMON_REFERENCE_H diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/span b/libcudacxx/include/cuda/std/detail/libcxx/include/span index 042d2f029c5..19fdea2f4ce 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/span +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/span @@ -203,18 +203,7 @@ _CCCL_INLINE_VAR constexpr bool __is_std_span> = true; template _CCCL_CONCEPT __span_array_convertible = _CCCL_TRAIT(is_convertible, _From (*)[], _To (*)[]); -template -_CCCL_INLINE_VAR constexpr bool __is_std_initializer_list = false; - -template -_CCCL_INLINE_VAR constexpr bool __is_std_initializer_list> = true; - -// We want to ensure that span interacts nicely with containers that might not have had the ranges treatment -# if defined(__cpp_lib_ranges) && !_CCCL_COMPILER(MSVC2017) -# define _CCCL_SPAN_USES_RANGES -# endif // __cpp_lib_ranges && !_CCCL_COMPILER(MSVC2017) - -# if defined(_CCCL_SPAN_USES_RANGES) +# if !_CCCL_COMPILER(MSVC2017) template _CCCL_CONCEPT_FRAGMENT( __span_compatible_range_, @@ -223,15 +212,43 @@ _CCCL_CONCEPT_FRAGMENT( requires(_CUDA_VRANGES::sized_range<_Range>), requires((_CUDA_VRANGES::borrowed_range<_Range> || _CCCL_TRAIT(is_const, _ElementType))), requires((!_CCCL_TRAIT(is_array, remove_cvref_t<_Range>))), - requires((!__is_std_span> && !__is_std_array> - && !__is_std_initializer_list>) ), + requires((!__is_std_span> && !__is_std_array>) ), requires(_CCCL_TRAIT( is_convertible, remove_reference_t<_CUDA_VRANGES::range_reference_t<_Range>> (*)[], 
_ElementType (*)[])))); template _CCCL_CONCEPT __span_compatible_range = _CCCL_FRAGMENT(__span_compatible_range_, _Range, _ElementType); -# if _CCCL_STD_VER >= 2020 +# else // // ^^^ !_CCCL_COMPILER(MSVC2017) ^^^ / vvv _CCCL_COMPILER(MSVC2017) vvv + +template +_CCCL_INLINE_VAR constexpr bool __span_compatible_range = false; + +template +_CCCL_INLINE_VAR constexpr bool __span_compatible_range< + _Range, + _ElementType, + void_t< + // // is a contiguous range + // enable_if_t<_CUDA_VRANGES::contiguous_range<_Range>, nullptr_t>, + // // is a sized range + // enable_if_t<_CUDA_VRANGES::sized_range<_Range>, nullptr_t>, + // // is a borrowed range or ElementType is const + // enable_if_t<(_CUDA_VRANGES::borrowed_range<_Range> || _CCCL_TRAIT(is_const, _ElementType)), nullptr_t>, + // is not a C-style array + enable_if_t), nullptr_t>, + // is not a specialization of span + enable_if_t>, nullptr_t>, + // is not a specialization of array + enable_if_t>, nullptr_t>, + // remove_pointer_t(*)[] is convertible to ElementType(*)[] + enable_if_t<_CCCL_TRAIT(is_convertible, + remove_pointer_t()))> (*)[], + _ElementType (*)[]), + nullptr_t>>> = true; +# endif // _CCCL_COMPILER(MSVC2017) + +# if _CCCL_STD_VER >= 2020 template _CCCL_CONCEPT __span_compatible_iterator = contiguous_iterator<_It> && __span_array_convertible>, _Tp>; @@ -239,7 +256,7 @@ _CCCL_CONCEPT __span_compatible_iterator = template _CCCL_CONCEPT __span_compatible_sentinel_for = sized_sentinel_for<_Sentinel, _It> && !_CCCL_TRAIT(is_convertible, _Sentinel, size_t); -# else // ^^^ C++20 ^^^ / vvv C++17 vvv +# else // ^^^ C++20 ^^^ / vvv C++17 vvv template _CCCL_CONCEPT_FRAGMENT(__span_compatible_iterator_, requires()(requires(contiguous_iterator<_It>), @@ -255,33 +272,7 @@ _CCCL_CONCEPT_FRAGMENT( template _CCCL_CONCEPT __span_compatible_sentinel_for = _CCCL_FRAGMENT(__span_compatible_sentinel_for_, _Sentinel, _It); -# endif // _CCCL_STD_VER <= 2017 -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - -template -_CCCL_INLINE_VAR constexpr bool __is_span_compatible_container = false; - -template -_CCCL_INLINE_VAR constexpr bool __is_span_compatible_container< - _Container, - _ElementType, - void_t< - // is not a specialization of span - enable_if_t>, nullptr_t>, - // is not a specialization of array - enable_if_t>, nullptr_t>, - // is not a specialization of array - enable_if_t>, nullptr_t>, - // is_array_v is false, - enable_if_t), nullptr_t>, - // data(cont) and size(cont) are well formed - decltype(_CUDA_VSTD::data(_CUDA_VSTD::declval<_Container&>())), - decltype(_CUDA_VSTD::size(_CUDA_VSTD::declval<_Container&>())), - // remove_pointer_t(*)[] is convertible to ElementType(*)[] - enable_if_t()))> (*)[], - _ElementType (*)[]>::value, - nullptr_t>>> = true; -# endif // !_CCCL_SPAN_USES_RANGES +# endif // _CCCL_STD_VER <= 2017 # if _CCCL_STD_VER >= 2020 @@ -350,7 +341,6 @@ public: _CCCL_HIDE_FROM_ABI span(const span&) noexcept = default; _CCCL_HIDE_FROM_ABI span& operator=(const span&) noexcept = default; -# if defined(_CCCL_SPAN_USES_RANGES) _CCCL_TEMPLATE(class _It) _CCCL_REQUIRES(__span_compatible_iterator<_It, element_type>) _LIBCUDACXX_HIDE_FROM_ABI constexpr explicit span(_It __first, size_type __count) @@ -370,20 +360,6 @@ public: _CCCL_ASSERT(__last - __first == _Extent, "invalid range in span's constructor (iterator, sentinel): last - first != extent"); } -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(pointer __ptr, size_type __count) - : 
__data_{__ptr} - { - (void) __count; - _CCCL_ASSERT(_Extent == __count, "size mismatch in span's constructor (ptr, len)"); - } - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(pointer __f, pointer __l) - : __data_{__f} - { - (void) __l; - _CCCL_ASSERT(_Extent == distance(__f, __l), "size mismatch in span's constructor (ptr, ptr)"); - } -# endif // !_CCCL_SPAN_USES_RANGES # if _CCCL_COMPILER(NVRTC) || _CCCL_COMPILER(MSVC2017) template = 0> @@ -408,7 +384,6 @@ public: : __data_{__arr.data()} {} -# if defined(_CCCL_SPAN_USES_RANGES) _CCCL_TEMPLATE(class _Range) _CCCL_REQUIRES(__span_compatible_range<_Range, element_type>) _LIBCUDACXX_HIDE_FROM_ABI constexpr explicit span(_Range&& __r) @@ -416,23 +391,6 @@ public: { _CCCL_ASSERT(_CUDA_VRANGES::size(__r) == _Extent, "size mismatch in span's constructor (range)"); } -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - _CCCL_TEMPLATE(class _Container) - _CCCL_REQUIRES(__is_span_compatible_container<_Container, _Tp>) - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(_Container& __c) noexcept(noexcept(_CUDA_VSTD::data(__c))) - : __data_{_CUDA_VSTD::data(__c)} - { - _CCCL_ASSERT(_Extent == _CUDA_VSTD::size(__c), "size mismatch in span's constructor (other span)"); - } - - _CCCL_TEMPLATE(class _Container) - _CCCL_REQUIRES(__is_span_compatible_container<_Container, const _Tp>) - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(const _Container& __c) noexcept(noexcept(_CUDA_VSTD::data(__c))) - : __data_{_CUDA_VSTD::data(__c)} - { - _CCCL_ASSERT(_Extent == _CUDA_VSTD::size(__c), "size mismatch in span's constructor (other span)"); - } -# endif // !_CCCL_SPAN_USES_RANGES _CCCL_TEMPLATE(class _OtherElementType, size_t _Extent2 = _Extent) _CCCL_REQUIRES((_Extent2 != dynamic_extent) _CCCL_AND __span_array_convertible<_OtherElementType, element_type>) @@ -613,7 +571,6 @@ public: _CCCL_HIDE_FROM_ABI span(const span&) noexcept = default; _CCCL_HIDE_FROM_ABI span& operator=(const span&) noexcept = default; -# if defined(_CCCL_SPAN_USES_RANGES) _CCCL_TEMPLATE(class _It) _CCCL_REQUIRES(__span_compatible_iterator<_It, element_type>) _LIBCUDACXX_HIDE_FROM_ABI constexpr span(_It __first, size_type __count) @@ -630,17 +587,6 @@ public: _CCCL_ASSERT(__last - __first >= 0, "invalid range in span's constructor (iterator, sentinel)"); } -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(pointer __ptr, size_type __count) - : __data_{__ptr} - , __size_{__count} - {} - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(pointer __f, pointer __l) - : __data_{__f} - , __size_{static_cast(__l - __f)} - {} -# endif // !_CCCL_SPAN_USES_RANGES - template _LIBCUDACXX_HIDE_FROM_ABI constexpr span(type_identity_t (&__arr)[_Sz]) noexcept : __data_{__arr} @@ -661,28 +607,12 @@ public: , __size_{_Sz} {} -# if defined(_CCCL_SPAN_USES_RANGES) _CCCL_TEMPLATE(class _Range) _CCCL_REQUIRES(__span_compatible_range<_Range, element_type>) _LIBCUDACXX_HIDE_FROM_ABI constexpr span(_Range&& __r) : __data_(_CUDA_VRANGES::data(__r)) , __size_{_CUDA_VRANGES::size(__r)} {} -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - _CCCL_TEMPLATE(class _Container) - _CCCL_REQUIRES(__is_span_compatible_container<_Container, _Tp>) - _LIBCUDACXX_HIDE_FROM_ABI constexpr span(_Container& __c) - : __data_{_CUDA_VSTD::data(__c)} - , __size_{(size_type) _CUDA_VSTD::size(__c)} - {} - - _CCCL_TEMPLATE(class _Container) - _CCCL_REQUIRES(__is_span_compatible_container<_Container, const _Tp>) - _LIBCUDACXX_HIDE_FROM_ABI constexpr 
span(const _Container& __c) - : __data_{_CUDA_VSTD::data(__c)} - , __size_{(size_type) _CUDA_VSTD::size(__c)} - {} -# endif // !_CCCL_SPAN_USES_RANGES _CCCL_TEMPLATE(class _OtherElementType, size_t _OtherExtent) _CCCL_REQUIRES(__span_array_convertible<_OtherElementType, element_type>) @@ -812,12 +742,12 @@ public: _LIBCUDACXX_HIDE_FROM_ABI span __as_bytes() const noexcept { - return {reinterpret_cast(data()), size_bytes()}; + return span{reinterpret_cast(data()), size_bytes()}; } _LIBCUDACXX_HIDE_FROM_ABI span __as_writable_bytes() const noexcept { - return {reinterpret_cast(data()), size_bytes()}; + return span{reinterpret_cast(data()), size_bytes()}; } private: @@ -853,8 +783,6 @@ _CCCL_HOST_DEVICE span(array<_Tp, _Sz>&) -> span<_Tp, _Sz>; template _CCCL_HOST_DEVICE span(const array<_Tp, _Sz>&) -> span; -# if defined(_CCCL_SPAN_USES_RANGES) - _CCCL_TEMPLATE(class _It, class _EndOrSize) _CCCL_REQUIRES(contiguous_iterator<_It>) _CCCL_HOST_DEVICE span(_It, @@ -864,23 +792,11 @@ _CCCL_TEMPLATE(class _Range) _CCCL_REQUIRES(_CUDA_VRANGES::contiguous_range<_Range>) _CCCL_HOST_DEVICE span(_Range&&) -> span>>; -# else // ^^^ _CCCL_SPAN_USES_RANGES ^^^ / vvv !_CCCL_SPAN_USES_RANGES vvv - -_CCCL_TEMPLATE(class _Container) -_CCCL_REQUIRES(__is_span_compatible_container<_Container, typename _Container::value_type>) -_CCCL_HOST_DEVICE span(_Container&) -> span; - -_CCCL_TEMPLATE(class _Container) -_CCCL_REQUIRES(__is_span_compatible_container<_Container, const typename _Container::value_type>) -_CCCL_HOST_DEVICE span(const _Container&) -> span; - -# endif // !_CCCL_SPAN_USES_RANGES - #endif // _CCCL_STD_VER >= 2017 _LIBCUDACXX_END_NAMESPACE_STD -#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#if _CCCL_STD_VER >= 2014 && !_CCCL_COMPILER(MSVC2017) _LIBCUDACXX_BEGIN_NAMESPACE_RANGES template _CCCL_INLINE_VAR constexpr bool enable_borrowed_range> = true; @@ -888,6 +804,6 @@ _CCCL_INLINE_VAR constexpr bool enable_borrowed_range> = true template _CCCL_INLINE_VAR constexpr bool enable_view> = true; _LIBCUDACXX_END_NAMESPACE_RANGES -#endif // _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017) +#endif // _CCCL_STD_VER >= 2014 && !_CCCL_COMPILER(MSVC2017) #endif // _LIBCUDACXX_SPAN diff --git a/libcudacxx/test/libcudacxx/std/containers/views/views.span/enable_borrowed_range.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/containers/views/views.span/enable_borrowed_range.compile.pass.cpp index 93f9165ae59..81b36ef1d43 100644 --- a/libcudacxx/test/libcudacxx/std/containers/views/views.span/enable_borrowed_range.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/containers/views/views.span/enable_borrowed_range.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // UNSUPPORTED: msvc-19.16 // @@ -22,11 +22,11 @@ int main(int, char**) { - static_assert(cuda::std::ranges::enable_borrowed_range>); - static_assert(cuda::std::ranges::enable_borrowed_range>); - static_assert(cuda::std::ranges::enable_borrowed_range>); - static_assert(!cuda::std::ranges::enable_borrowed_range&>); - static_assert(!cuda::std::ranges::enable_borrowed_range const>); + static_assert(cuda::std::ranges::enable_borrowed_range>, ""); + static_assert(cuda::std::ranges::enable_borrowed_range>, ""); + static_assert(cuda::std::ranges::enable_borrowed_range>, ""); + static_assert(!cuda::std::ranges::enable_borrowed_range&>, ""); + static_assert(!cuda::std::ranges::enable_borrowed_range const>, ""); return 0; 
} diff --git a/libcudacxx/test/libcudacxx/std/containers/views/views.span/range_concept_conformance.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/containers/views/views.span/range_concept_conformance.compile.pass.cpp index d0959bf08c3..08ba0d47aaa 100644 --- a/libcudacxx/test/libcudacxx/std/containers/views/views.span/range_concept_conformance.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/containers/views/views.span/range_concept_conformance.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // UNSUPPORTED: msvc-19.16 // span @@ -18,23 +18,23 @@ using range = cuda::std::span; -static_assert(cuda::std::same_as, range::iterator>); -static_assert(cuda::std::ranges::common_range); -static_assert(cuda::std::ranges::random_access_range); -static_assert(cuda::std::ranges::contiguous_range); -static_assert(cuda::std::ranges::view && cuda::std::ranges::enable_view); -static_assert(cuda::std::ranges::sized_range); -static_assert(cuda::std::ranges::borrowed_range); -static_assert(cuda::std::ranges::viewable_range); +static_assert(cuda::std::same_as, range::iterator>, ""); +static_assert(cuda::std::ranges::common_range, ""); +static_assert(cuda::std::ranges::random_access_range, ""); +static_assert(cuda::std::ranges::contiguous_range, ""); +static_assert(cuda::std::ranges::view && cuda::std::ranges::enable_view, ""); +static_assert(cuda::std::ranges::sized_range, ""); +static_assert(cuda::std::ranges::borrowed_range, ""); +static_assert(cuda::std::ranges::viewable_range, ""); -static_assert(cuda::std::same_as, range::iterator>); -static_assert(cuda::std::ranges::common_range); -static_assert(cuda::std::ranges::random_access_range); -static_assert(cuda::std::ranges::contiguous_range); -static_assert(!cuda::std::ranges::view && !cuda::std::ranges::enable_view); -static_assert(cuda::std::ranges::sized_range); -static_assert(cuda::std::ranges::borrowed_range); -static_assert(cuda::std::ranges::viewable_range); +static_assert(cuda::std::same_as, range::iterator>, ""); +static_assert(cuda::std::ranges::common_range, ""); +static_assert(cuda::std::ranges::random_access_range, ""); +static_assert(cuda::std::ranges::contiguous_range, ""); +static_assert(!cuda::std::ranges::view && !cuda::std::ranges::enable_view, ""); +static_assert(cuda::std::ranges::sized_range, ""); +static_assert(cuda::std::ranges::borrowed_range, ""); +static_assert(cuda::std::ranges::viewable_range, ""); int main(int, char**) { diff --git a/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/deduct.pass.cpp b/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/deduct.pass.cpp index 6f3434b5013..c6fe12b759b 100644 --- a/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/deduct.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/deduct.pass.cpp @@ -52,7 +52,6 @@ __host__ __device__ void test_iterator_sentinel() assert(s.data() == cuda::std::data(arr)); } -#if defined(_CCCL_SPAN_USES_RANGES) // P3029R1: deduction from `integral_constant` { cuda::std::span s{cuda::std::begin(arr), cuda::std::integral_constant{}}; @@ -60,7 +59,6 @@ __host__ __device__ void test_iterator_sentinel() assert(s.size() == cuda::std::size(arr)); assert(s.data() == cuda::std::data(arr)); } -#endif // _CCCL_SPAN_USES_RANGES } __host__ __device__ void test_c_array() diff --git 
a/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/initializer_list.pass.cpp b/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/initializer_list.pass.cpp index d84d0b01115..3f5990de3e8 100644 --- a/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/initializer_list.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/containers/views/views.span/span.cons/initializer_list.pass.cpp @@ -28,8 +28,8 @@ using cuda::std::is_constructible; // Constructor constrains static_assert(is_constructible, cuda::std::initializer_list>::value, ""); static_assert(is_constructible, cuda::std::initializer_list>::value, ""); -static_assert(!is_constructible, cuda::std::initializer_list>::value, ""); -static_assert(!is_constructible, cuda::std::initializer_list>::value, ""); +static_assert(is_constructible, cuda::std::initializer_list>::value, ""); +static_assert(is_constructible, cuda::std::initializer_list>::value, ""); static_assert(!is_constructible, cuda::std::initializer_list>::value, ""); static_assert(!is_constructible, cuda::std::initializer_list>::value, ""); diff --git a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/cxx20_iterator_traits.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/cxx20_iterator_traits.compile.pass.cpp index 87bb2900fc8..1ce2ee9a691 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/cxx20_iterator_traits.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/cxx20_iterator_traits.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // template // struct iterator_traits; @@ -26,32 +26,33 @@ #include "test_macros.h" template -inline constexpr bool has_iterator_concept_v = false; +_CCCL_INLINE_VAR constexpr bool has_iterator_concept_v = false; template -inline constexpr bool has_iterator_concept_v> = true; +_CCCL_INLINE_VAR constexpr bool has_iterator_concept_v> = + true; template , int> = 0> __host__ __device__ constexpr void test_iter_concept() { - static_assert(cuda::std::same_as); + static_assert(cuda::std::same_as, ""); } template , int> = 0> __host__ __device__ constexpr void test_iter_concept() { - static_assert(!has_iterator_concept_v); + static_assert(!has_iterator_concept_v, ""); } template __host__ __device__ constexpr bool test() { using Traits = cuda::std::iterator_traits; - static_assert(cuda::std::same_as); - static_assert(cuda::std::same_as); - static_assert(cuda::std::same_as); - static_assert(cuda::std::same_as); - static_assert(cuda::std::same_as); + static_assert(cuda::std::same_as, ""); + static_assert(cuda::std::same_as, ""); + static_assert(cuda::std::same_as, ""); + static_assert(cuda::std::same_as, ""); + static_assert(cuda::std::same_as, ""); test_iter_concept(); @@ -82,8 +83,8 @@ __host__ __device__ constexpr bool testMutable() // exists for any particular non-pointer type, we assume it is present // only for pointers. // -static_assert(testMutable::iterator, cuda::std::random_access_iterator_tag, int>()); -static_assert(testConst::const_iterator, cuda::std::random_access_iterator_tag, int>()); +static_assert(testMutable::iterator, cuda::std::random_access_iterator_tag, int>(), ""); +static_assert(testConst::const_iterator, cuda::std::random_access_iterator_tag, int>(), ""); // Local test iterators. 
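For context, most of the test churn in this and the following hunks is mechanical fallout of lowering the dialect floor: the terse `static_assert(cond)` form and `inline` namespace-scope variables are both C++17 features, so once a test drops "UNSUPPORTED: c++14" every assertion gains an empty message and the detection variable templates are spelled with `_CCCL_INLINE_VAR`. A minimal sketch of that pattern follows, using a hypothetical `EXAMPLE_INLINE_VAR` macro since `_CCCL_INLINE_VAR`'s real definition does not appear in this patch:

// Illustrative only: keep a variable template usable in both C++14 and C++17.
#if __cplusplus >= 201703L
#  define EXAMPLE_INLINE_VAR inline
#else // `inline` variables do not exist before C++17; expand to nothing.
#  define EXAMPLE_INLINE_VAR
#endif

template <class T>
EXAMPLE_INLINE_VAR constexpr bool is_int_v = false;

template <>
EXAMPLE_INLINE_VAR constexpr bool is_int_v<int> = true;

// The terse `static_assert(is_int_v<int>);` requires C++17; C++14 needs the
// two-argument form, which is why "" is appended throughout these tests.
static_assert(is_int_v<int>, "");
static_assert(!is_int_v<long>, "");

int main(int, char**)
{
  return 0;
}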
@@ -101,12 +102,12 @@ struct AllMembers {}; }; using AllMembersTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct NoPointerMember { @@ -122,12 +123,12 @@ struct NoPointerMember __host__ __device__ value_type* operator->() const; }; using NoPointerMemberTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct IterConcept { @@ -146,12 +147,12 @@ struct IterConcept {}; }; using IterConceptTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyInput { @@ -178,12 +179,12 @@ struct cuda::std::incrementable_traits using difference_type = short; }; using LegacyInputTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyInputNoValueType { @@ -209,12 +210,12 @@ struct cuda::std::indirectly_readable_traits using value_type = LegacyInputNoValueType::not_value_type; }; using LegacyInputNoValueTypeTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyForward { @@ -240,12 +241,12 @@ struct cuda::std::incrementable_traits using difference_type = short; // or any signed integral type }; using LegacyForwardTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); 
-static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyBidirectional { @@ -264,12 +265,13 @@ struct LegacyBidirectional __host__ __device__ friend short operator-(LegacyBidirectional, LegacyBidirectional); }; using LegacyBidirectionalTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, + ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); // Almost a random access iterator except it is missing operator-(It, It). struct MinusNotDeclaredIter @@ -326,12 +328,13 @@ struct cuda::std::incrementable_traits using difference_type = short; }; using MinusNotDeclaredIterTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, + ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct WrongSubscriptReturnType { @@ -383,12 +386,12 @@ struct WrongSubscriptReturnType }; using WrongSubscriptReturnTypeTraits = cuda::std::iterator_traits; static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); + cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyRandomAccess { @@ -417,12 +420,13 @@ struct LegacyRandomAccess __host__ __device__ friend LegacyRandomAccess operator+(int, LegacyRandomAccess); }; using LegacyRandomAccessTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, + ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyRandomAccessSpaceship { @@ -493,14 +497,14 @@ struct cuda::std::incrementable_traits }; using LegacyRandomAccessSpaceshipTraits = cuda::std::iterator_traits; static_assert( - cuda::std::same_as); + cuda::std::same_as, ""); static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); + cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); + cuda::std::same_as, ""); 
+static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); // For output iterators, value_type, difference_type, and reference may be void. struct BareLegacyOutput @@ -512,12 +516,12 @@ struct BareLegacyOutput __host__ __device__ BareLegacyOutput operator++(int); }; using BareLegacyOutputTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); // The operator- means we get difference_type. struct LegacyOutputWithMinus @@ -531,12 +535,12 @@ struct LegacyOutputWithMinus // Lacking operator==, this is a LegacyIterator but not a LegacyInputIterator. }; using LegacyOutputWithMinusTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyOutputWithMemberTypes { @@ -557,12 +561,13 @@ struct LegacyOutputWithMemberTypes // Since (*it) is not convertible to value_type, this is not a LegacyInputIterator. }; using LegacyOutputWithMemberTypesTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, + ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); struct LegacyRandomAccessSpecialized { @@ -628,88 +633,89 @@ struct cuda::std::iterator_traits }; using LegacyRandomAccessSpecializedTraits = cuda::std::iterator_traits; static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); + cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); // Other test iterators. 
using InputTestIteratorTraits = cuda::std::iterator_traits>; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using OutputTestIteratorTraits = cuda::std::iterator_traits>; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using ForwardTestIteratorTraits = cuda::std::iterator_traits>; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using BidirectionalTestIteratorTraits = cuda::std::iterator_traits>; static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); + cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using RandomAccessTestIteratorTraits = cuda::std::iterator_traits>; static_assert( - cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); + cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using ContiguousTestIteratorTraits = cuda::std::iterator_traits>; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, + ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using Cpp17BasicIteratorTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); 
+static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using Cpp17InputIteratorTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); using Cpp17ForwardIteratorTraits = cuda::std::iterator_traits; -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(cuda::std::same_as); -static_assert(!has_iterator_concept_v); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(cuda::std::same_as, ""); +static_assert(!has_iterator_concept_v, ""); int main(int, char**) { diff --git a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/iter_reference_t.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/iter_reference_t.compile.pass.cpp index c5f0a019c70..cfb7dfc6c2b 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/iter_reference_t.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/iterator.traits/iter_reference_t.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // template // using iter_reference_t = decltype(*declval()); @@ -17,10 +17,10 @@ #include "test_iterators.h" -static_assert(cuda::std::same_as>, int&>); -static_assert(cuda::std::same_as>, int&>); -static_assert(cuda::std::same_as>, int&>); -static_assert(cuda::std::same_as>, int&>); +static_assert(cuda::std::same_as>, int&>, ""); +static_assert(cuda::std::same_as>, int&>, ""); +static_assert(cuda::std::same_as>, int&>, ""); +static_assert(cuda::std::same_as>, int&>, ""); int main(int, char**) { diff --git a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/std.iterator.tags/contiguous_iterator_tag.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/std.iterator.tags/contiguous_iterator_tag.pass.cpp index e57be04cd55..c5764895b62 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/std.iterator.tags/contiguous_iterator_tag.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.primitives/std.iterator.tags/contiguous_iterator_tag.pass.cpp @@ -11,7 +11,7 @@ // struct contiguous_iterator_tag : public random_access_iterator_tag {}; -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 #include #include @@ -23,8 +23,9 @@ int main(int, char**) cuda::std::contiguous_iterator_tag tag; ((void) tag); // Prevent unused warning static_assert( - (cuda::std::is_base_of::value)); - static_assert((!cuda::std::is_base_of::value)); + (cuda::std::is_base_of::value), ""); + static_assert((!cuda::std::is_base_of::value), + ""); return 0; } diff --git 
a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.compile.pass.cpp index 8f1a6b608e6..3f89e8b84fc 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // template // concept indirectly_copyable; @@ -29,43 +29,43 @@ struct CopyOnly }; // Can copy the underlying objects between pointers. -static_assert(cuda::std::indirectly_copyable); -static_assert(cuda::std::indirectly_copyable); +static_assert(cuda::std::indirectly_copyable, ""); +static_assert(cuda::std::indirectly_copyable, ""); // Can't copy if the output pointer is const. -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); // Can copy from a pointer into an array but arrays aren't considered indirectly copyable-from. -static_assert(cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); +static_assert(cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); // Can't copy between non-pointer types. -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); // Check some less common types. -static_assert(!cuda::std::indirectly_movable); -static_assert(!cuda::std::indirectly_movable); -static_assert(!cuda::std::indirectly_movable); -static_assert(!cuda::std::indirectly_movable); -static_assert(!cuda::std::indirectly_movable); +static_assert(!cuda::std::indirectly_movable, ""); +static_assert(!cuda::std::indirectly_movable, ""); +static_assert(!cuda::std::indirectly_movable, ""); +static_assert(!cuda::std::indirectly_movable, ""); +static_assert(!cuda::std::indirectly_movable, ""); // Can't copy move-only objects. -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); // Can copy copy-only objects. 
#ifndef TEST_COMPILER_MSVC_2017 // MSVC2017 has issues determining common_reference -static_assert(cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); -static_assert(cuda::std::indirectly_copyable); -static_assert(!cuda::std::indirectly_copyable); +static_assert(cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); +static_assert(cuda::std::indirectly_copyable, ""); +static_assert(!cuda::std::indirectly_copyable, ""); #endif // TEST_COMPILER_MSVC_2017 template @@ -77,13 +77,13 @@ struct PointerTo #ifndef TEST_COMPILER_MSVC_2017 // MSVC2017 has issues determining common_reference // Can copy through a dereferenceable class. -static_assert(cuda::std::indirectly_copyable>); -static_assert(!cuda::std::indirectly_copyable>); -static_assert(cuda::std::indirectly_copyable, PointerTo>); -static_assert(!cuda::std::indirectly_copyable, PointerTo>); -static_assert(cuda::std::indirectly_copyable>); -static_assert(cuda::std::indirectly_copyable, CopyOnly*>); -static_assert(cuda::std::indirectly_copyable, PointerTo>); +static_assert(cuda::std::indirectly_copyable>, ""); +static_assert(!cuda::std::indirectly_copyable>, ""); +static_assert(cuda::std::indirectly_copyable, PointerTo>, ""); +static_assert(!cuda::std::indirectly_copyable, PointerTo>, ""); +static_assert(cuda::std::indirectly_copyable>, ""); +static_assert(cuda::std::indirectly_copyable, CopyOnly*>, ""); +static_assert(cuda::std::indirectly_copyable, PointerTo>, ""); #endif // TEST_COMPILER_MSVC_2017 int main(int, char**) diff --git a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.subsumption.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.subsumption.compile.pass.cpp index 3f862c463e4..dc97f45d014 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.subsumption.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable.subsumption.compile.pass.cpp @@ -29,7 +29,7 @@ __host__ __device__ constexpr bool indirectly_copyable_subsumption() return true; } -static_assert(indirectly_copyable_subsumption()); +static_assert(indirectly_copyable_subsumption(), ""); int main(int, char**) { diff --git a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable_storable.compile.pass.cpp b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable_storable.compile.pass.cpp index dc63264f480..33d6ed0477f 100644 --- a/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable_storable.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/iterators/iterator.requirements/alg.req.ind.copy/indirectly_copyable_storable.compile.pass.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -// UNSUPPORTED: c++03, c++11, c++14 +// UNSUPPORTED: c++03, c++11 // template // concept indirectly_copyable_storable; @@ -35,17 +35,17 @@ struct PointerTo // Copying the underlying object between pointers (or dereferenceable classes) works. This is a non-exhaustive check // because this functionality comes from `indirectly_copyable`. 
-static_assert(cuda::std::indirectly_copyable_storable); -static_assert(cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable, PointerTo>); +static_assert(cuda::std::indirectly_copyable_storable, ""); +static_assert(cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, PointerTo>, ""); // `indirectly_copyable_storable` requires the type to be `copyable`, which in turns requires it to be `movable`. -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable, PointerTo>); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, PointerTo>, ""); // The dereference operator returns a different type from `value_type` and the reference type cannot be assigned from a // non-const lvalue of `ValueType` (but all other forms of assignment from `ValueType` work). @@ -70,12 +70,14 @@ struct NoLvalueAssignment __host__ __device__ ReferenceType& operator*() const; }; -static_assert(cuda::std::indirectly_writable>); -static_assert(!cuda::std::indirectly_writable&>); -static_assert(cuda::std::indirectly_writable&>); -static_assert(cuda::std::indirectly_writable&&>); -static_assert(cuda::std::indirectly_writable&&>); -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(cuda::std::indirectly_writable>, ""); +static_assert(!cuda::std::indirectly_writable&>, ""); +static_assert(cuda::std::indirectly_writable&>, + ""); +static_assert(cuda::std::indirectly_writable&&>, ""); +static_assert(cuda::std::indirectly_writable&&>, + ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); // The dereference operator returns a different type from `value_type` and the reference type cannot be assigned from a // const lvalue of `ValueType` (but all other forms of assignment from `ValueType` work). @@ -101,16 +103,18 @@ struct NoConstLvalueAssignment }; static_assert( - cuda::std::indirectly_writable>); + cuda::std::indirectly_writable>, ""); static_assert( - cuda::std::indirectly_writable&>); + cuda::std::indirectly_writable&>, ""); static_assert( - !cuda::std::indirectly_writable&>); + !cuda::std::indirectly_writable&>, + ""); static_assert( - cuda::std::indirectly_writable&&>); + cuda::std::indirectly_writable&&>, ""); static_assert( - cuda::std::indirectly_writable&&>); -static_assert(!cuda::std::indirectly_copyable_storable); + cuda::std::indirectly_writable&&>, + ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); // The dereference operator returns a different type from `value_type` and the reference type cannot be assigned from a // non-const rvalue of `ValueType` (but all other forms of assignment from `ValueType` work). 
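The NoLvalueAssignment/NoConstLvalueAssignment/NoRvalueAssignment archetypes in the hunks around here all probe the same extra requirement that `indirectly_copyable_storable` layers on top of `indirectly_copyable`: the output iterator must accept the stored `iter_value_t` in every value category, not just the source's reference type. A compile-only sketch of that idea, assuming nvcc (for the `__host__ __device__` annotations) and hypothetical `Value`/`Proxy`/`Out` types that are not part of the patch:

#include <cuda/std/iterator>

struct Value {};

// Proxy reference that rejects assignment from a non-const rvalue Value.
struct Proxy
{
  __host__ __device__ Proxy& operator=(Value&);
  __host__ __device__ Proxy& operator=(const Value&);
  __host__ __device__ Proxy& operator=(Value&&) = delete;
  __host__ __device__ Proxy& operator=(const Value&&);
};

// Output iterator whose dereference yields the proxy, not Value itself.
struct Out
{
  using value_type      = Value;
  using difference_type = int;
  __host__ __device__ Proxy& operator*() const;
  __host__ __device__ Out& operator++();
  __host__ __device__ Out operator++(int);
};

// Copying through the source's reference type (const Value&) works...
static_assert(cuda::std::indirectly_copyable<const Value*, Out>, "");
// ...but storing a Value and writing it back as a non-const rvalue does not.
static_assert(!cuda::std::indirectly_copyable_storable<const Value*, Out>, "");

int main(int, char**)
{
  return 0;
}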
@@ -135,12 +139,14 @@ struct NoRvalueAssignment __host__ __device__ ReferenceType& operator*() const; }; -static_assert(cuda::std::indirectly_writable>); -static_assert(cuda::std::indirectly_writable&>); -static_assert(cuda::std::indirectly_writable&>); -static_assert(!cuda::std::indirectly_writable&&>); -static_assert(cuda::std::indirectly_writable&&>); -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(cuda::std::indirectly_writable>, ""); +static_assert(cuda::std::indirectly_writable&>, ""); +static_assert(cuda::std::indirectly_writable&>, + ""); +static_assert(!cuda::std::indirectly_writable&&>, ""); +static_assert(cuda::std::indirectly_writable&&>, + ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); // The dereference operator returns a different type from `value_type` and the reference type cannot be assigned from a // const rvalue of `ValueType` (but all other forms of assignment from `ValueType` work). @@ -166,16 +172,17 @@ struct NoConstRvalueAssignment }; static_assert( - cuda::std::indirectly_writable>); + cuda::std::indirectly_writable>, ""); static_assert( - cuda::std::indirectly_writable&>); + cuda::std::indirectly_writable&>, ""); static_assert( - cuda::std::indirectly_writable&>); + cuda::std::indirectly_writable&>, ""); static_assert( - cuda::std::indirectly_writable&&>); + cuda::std::indirectly_writable&&>, ""); static_assert( - !cuda::std::indirectly_writable&&>); -static_assert(!cuda::std::indirectly_copyable_storable); + !cuda::std::indirectly_writable&&>, + ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); struct DeletedCopyCtor { @@ -191,7 +198,7 @@ struct DeletedNonconstCopyCtor DeletedNonconstCopyCtor(DeletedNonconstCopyCtor&) = delete; DeletedNonconstCopyCtor& operator=(DeletedNonconstCopyCtor const&) = default; }; -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(!cuda::std::indirectly_copyable_storable, ""); #endif // TEST_STD_VER > 2017 || !defined(TEST_COMPILER_MSVC) struct DeletedMoveCtor @@ -208,7 +215,7 @@ struct DeletedConstMoveCtor DeletedConstMoveCtor(DeletedConstMoveCtor const&&) = delete; DeletedConstMoveCtor& operator=(DeletedConstMoveCtor&&) = default; }; -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(!cuda::std::indirectly_copyable_storable, ""); #endif // TEST_STD_VER > 2017 || !defined(TEST_COMPILER_MSVC) struct DeletedCopyAssignment @@ -225,7 +232,8 @@ struct DeletedNonconstCopyAssignment DeletedNonconstCopyAssignment& operator=(DeletedNonconstCopyAssignment const&) = default; DeletedNonconstCopyAssignment& operator=(DeletedNonconstCopyAssignment&) = delete; }; -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(!cuda::std::indirectly_copyable_storable, + ""); #endif // TEST_STD_VER > 2017 || !defined(TEST_COMPILER_MSVC) struct DeletedMoveAssignment @@ -240,11 +248,11 @@ struct DeletedConstMoveAssignment DeletedConstMoveAssignment& operator=(DeletedConstMoveAssignment&&) = delete; }; -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); -static_assert(!cuda::std::indirectly_copyable_storable); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); +static_assert(!cuda::std::indirectly_copyable_storable, ""); 
+static_assert(!cuda::std::indirectly_copyable_storable, ""); struct InconsistentIterator { @@ -267,7 +275,7 @@ struct InconsistentIterator // `ValueType` can be constructed with a `ReferenceType` and assigned to a `ReferenceType`, so it does model // `indirectly_copyable_storable`. -static_assert(cuda::std::indirectly_copyable_storable); +static_assert(cuda::std::indirectly_copyable_storable, ""); struct CommonType {}; @@ -292,22 +300,27 @@ struct NotConstructibleFromRefIn __host__ __device__ ReferenceType& operator*() const; }; +namespace cuda +{ +namespace std +{ template