Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[do not merge] std::linalg accessors and transposed_layout #2962

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
dee9385
draft of scaled accessor
fbusato Nov 25, 2024
0a33496
add scaled unit test
fbusato Nov 26, 2024
fb03ff2
Merge branch 'main' into linalg-accessors
fbusato Nov 26, 2024
c5c6fa4
add conjugated accessor
fbusato Nov 26, 2024
ce2cc84
refined scaled accessor implementation
fbusato Nov 26, 2024
1bf05d0
add [[nodiscard]]
fbusato Nov 26, 2024
1a261ff
add transposed function
fbusato Nov 26, 2024
8db4036
add conjugate_transposed
fbusato Nov 26, 2024
f48792a
fix internal names
fbusato Nov 26, 2024
67c59f9
replace inline lambda with function object
fbusato Nov 26, 2024
9d913be
add tests
fbusato Nov 26, 2024
77c2d66
Merge branch 'main' into linalg-accessors
fbusato Nov 26, 2024
3d06282
fix c++20 requires clause
fbusato Nov 27, 2024
c3d1464
Merge branch 'main' into linalg-accessors
fbusato Nov 27, 2024
5390bd1
add __cccl_lib_mdspan check
fbusato Nov 27, 2024
626f63a
prevent to include headers if the compiler is not supported
fbusato Nov 27, 2024
baa26f6
skip double noexcept for old compilers
fbusato Nov 27, 2024
4cce6e6
avoid deduction guides to prevent errors with old gcc versions
fbusato Nov 27, 2024
5031fb6
fix #endif position
fbusato Nov 27, 2024
8d20f84
Merge branch 'main' into linalg-accessors
fbusato Nov 27, 2024
bdee5a3
fix clang9/gcc9 compatibility
fbusato Nov 27, 2024
10c0f0a
remove redundant header
fbusato Nov 27, 2024
feed72e
avoid deduction guides for conjugate_transposed
fbusato Nov 27, 2024
30cf216
fix variable shadowing
fbusato Nov 27, 2024
a39d27a
fix nvrtc header bug
fbusato Nov 27, 2024
9699ce5
Merge branch 'main' into linalg-accessors
fbusato Dec 2, 2024
2da684d
relax noexcept(noexcept()) compiler filtering
fbusato Dec 2, 2024
8746de6
adopt concept for conj_if_needed
fbusato Dec 2, 2024
6445b49
add documentation
fbusato Dec 2, 2024
cf34003
fix compiler identification macro
fbusato Dec 3, 2024
ad47459
Update libcudacxx/include/cuda/std/__linalg/conjugate_transposed.h
fbusato Dec 3, 2024
9c621a5
Update libcudacxx/include/cuda/std/__linalg/conjugated.h
fbusato Dec 3, 2024
6d521d8
Update libcudacxx/include/cuda/std/__linalg/conjugated.h
fbusato Dec 3, 2024
ea0038f
Update libcudacxx/include/cuda/std/__linalg/conjugated.h
fbusato Dec 3, 2024
6f18bee
Update libcudacxx/include/cuda/std/__linalg/conjugate_transposed.h
fbusato Dec 3, 2024
4411e5d
add linalg reference in docs
fbusato Dec 3, 2024
22c37f6
fix documentation
fbusato Dec 3, 2024
fe7d176
add linalg header
fbusato Dec 3, 2024
48483f4
change test header
fbusato Dec 3, 2024
7fb5d60
remove redundant namespace specifications
fbusato Dec 3, 2024
627d354
Merge branch 'main' into linalg-accessors
fbusato Dec 3, 2024
b565ad6
add operator!=
fbusato Dec 3, 2024
5af13ad
fix new linalg documentation position
fbusato Dec 3, 2024
103f1b6
fix c++20 require clause
fbusato Dec 3, 2024
377bf53
fix requires expression again
fbusato Dec 3, 2024
eef1e17
remove forward_like duplication in docs
fbusato Dec 4, 2024
de30054
license update
fbusato Dec 5, 2024
8093083
Merge branch 'main' into linalg-accessors
fbusato Dec 5, 2024
140c843
add `unreachable` in standard_api.rst
fbusato Dec 5, 2024
ec199d0
Merge branch 'main' into linalg-accessors
fbusato Dec 18, 2024
71c7211
add missing constexpr
fbusato Dec 18, 2024
7b529bd
remove duplicate line in docs
fbusato Dec 18, 2024
03e7387
split scaled accessor constructors
fbusato Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/libcudacxx/standard_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ Feature availability:

- C++26 ``std::dims`` is available in C++14.

- C++23 ``forward_like``, ``to_underlying`` and ``unreachable`` from ``<utility>`` are available in C++11.
- C++26 ``std::linalg`` accessors, transposed layout, and related functions are available in C++17.

- ``scaled()`` and ``scaled_accessor``
- ``conjugated()`` and ``conjugated_accessor``
- ``transposed()`` and ``layout_transpose``
- ``conjugate_transposed()``

- C++23 ``forward_like``, ``to_underlying``, and ``unreachable`` from ``<utility>`` are available in C++11.

- C++23 ``is_scoped_enum`` in ``<type_traits>`` is available in C++11.

Expand Down
1 change: 1 addition & 0 deletions docs/libcudacxx/standard_api/numerics_library.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Numerics Library
numerics_library/bit
numerics_library/complex
numerics_library/numeric
numerics_library/linalg

Any Standard C++ header not listed below is omitted.

Expand Down
31 changes: 31 additions & 0 deletions docs/libcudacxx/standard_api/numerics_library/linalg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
.. _libcudacxx-standard-api-numerics-linalg:

``<cuda/std/linalg>``
============================================

Provided functionalities
------------------------

- ``scaled()`` `std::linalg::scaled <https://en.cppreference.com/w/cpp/numeric/linalg/scaled>`_
- ``scaled_accessor`` `std::linalg::scaled_accessor <https://en.cppreference.com/w/cpp/numeric/linalg/scaled_accessor>`_
- ``conjugated()`` `std::linalg::conjugated <https://en.cppreference.com/w/cpp/numeric/linalg/conjugated>`_
- ``conjugated_accessor`` `std::linalg::conjugated_accessor <https://en.cppreference.com/w/cpp/numeric/linalg/conjugated_accessor>`_
- ``transposed()`` `std::linalg::transposed <https://en.cppreference.com/w/cpp/numeric/linalg/transposed>`_
- ``layout_transpose`` `std::linalg::layout_transpose <https://en.cppreference.com/w/cpp/numeric/linalg/layout_transpose>`_
- ``conjugate_transposed()`` `std::linalg::conjugate_transposed <https://en.cppreference.com/w/cpp/numeric/linalg/conjugate_transposed>`_

Extensions
----------

- C++26 ``std::linalg`` accessors, transposed layout, and related functions are available in C++17

Omissions
---------

- Currently we do not expose any BLAS functions and layouts.

Restrictions
------------

- On device no exceptions are thrown in case of a bad access.
- MSVC is only supported with C++20
78 changes: 78 additions & 0 deletions libcudacxx/include/cuda/std/__linalg/conj_if_needed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
//===---------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___LINALG_CONJUGATE_IF_NEEDED_HPP
#define _LIBCUDACXX___LINALG_CONJUGATE_IF_NEEDED_HPP

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/version>

#if defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017

# include <cuda/std/__concepts/concept_macros.h>
# include <cuda/std/__type_traits/is_arithmetic.h>
# include <cuda/std/complex>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

namespace linalg
{

_LIBCUDACXX_BEGIN_NAMESPACE_CPO(__conj_if_needed)

template <class _Type>
_CCCL_CONCEPT _HasConj = _CCCL_REQUIRES_EXPR((_Type), _Type __a)(static_cast<void>(_CUDA_VSTD::conj(__a)));

struct __conj_if_needed
{
template <class _Type>
_LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator()(const _Type& __t) const
{
if constexpr (is_arithmetic_v<_Type> || !_HasConj<_Type>)
{
return __t;
}
else
{
return _CUDA_VSTD::conj(__t);
}
_CCCL_UNREACHABLE();
}
};

_LIBCUDACXX_END_NAMESPACE_CPO

inline namespace __cpo
{
_CCCL_GLOBAL_CONSTANT auto conj_if_needed = __conj_if_needed::__conj_if_needed{};

} // namespace __cpo
} // end namespace linalg

_LIBCUDACXX_END_NAMESPACE_STD

#endif // defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017
#endif // _LIBCUDACXX___LINALG_CONJUGATED_HPP
55 changes: 55 additions & 0 deletions libcudacxx/include/cuda/std/__linalg/conjugate_transposed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
//===---------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___LINALG_CONJUGATE_TRANSPOSED_HPP
#define _LIBCUDACXX___LINALG_CONJUGATE_TRANSPOSED_HPP

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/version>

#if defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017

# include <cuda/std/__linalg/conjugated.h>
# include <cuda/std/__linalg/transposed.h>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

namespace linalg
{

template <class _ElementType, class _Extents, class _Layout, class _Accessor>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
conjugate_transposed(mdspan<_ElementType, _Extents, _Layout, _Accessor> __a)
{
return conjugated(transposed(__a));
}

} // end namespace linalg

_LIBCUDACXX_END_NAMESPACE_STD

#endif // defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017
#endif // _LIBCUDACXX___LINALG_CONJUGATE_TRANSPOSED_HPP
141 changes: 141 additions & 0 deletions libcudacxx/include/cuda/std/__linalg/conjugated.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
//===---------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___LINALG_CONJUGATED_HPP
#define _LIBCUDACXX___LINALG_CONJUGATED_HPP

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/version>

#if defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017

# include <cuda/std/__linalg/conj_if_needed.h>
# include <cuda/std/__type_traits/add_const.h>
# include <cuda/std/__type_traits/is_arithmetic.h>
# include <cuda/std/__type_traits/remove_const.h>
# include <cuda/std/__utility/declval.h>
# include <cuda/std/mdspan>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

namespace linalg
{

template <class _NestedAccessor>
class conjugated_accessor
{
private:
using __nested_element_type = typename _NestedAccessor::element_type;
using __nc_result_type = decltype(conj_if_needed(_CUDA_VSTD::declval<__nested_element_type>()));

public:
using element_type = add_const_t<__nc_result_type>;
using reference = remove_const_t<element_type>;
using data_handle_type = typename _NestedAccessor::data_handle_type;
using offset_policy = conjugated_accessor<typename _NestedAccessor::offset_policy>;

_CCCL_HIDE_FROM_ABI constexpr conjugated_accessor() = default;

_LIBCUDACXX_HIDE_FROM_ABI constexpr conjugated_accessor(const _NestedAccessor& __acc)
: __nested_accessor_(__acc)
{}

_CCCL_TEMPLATE(class _OtherNestedAccessor)
_CCCL_REQUIRES(_CCCL_TRAIT(is_constructible, _NestedAccessor, const _OtherNestedAccessor&)
_CCCL_AND _CCCL_TRAIT(is_convertible, _OtherNestedAccessor, _NestedAccessor))
_LIBCUDACXX_HIDE_FROM_ABI constexpr conjugated_accessor(const conjugated_accessor<_OtherNestedAccessor>& __other)
: __nested_accessor_(__other.nested_accessor())
{}

_CCCL_TEMPLATE(class _OtherNestedAccessor)
_CCCL_REQUIRES(_CCCL_TRAIT(is_constructible, _NestedAccessor, const _OtherNestedAccessor&)
_CCCL_AND(!_CCCL_TRAIT(is_convertible, _OtherNestedAccessor, _NestedAccessor)))
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr conjugated_accessor(
const conjugated_accessor<_OtherNestedAccessor>& __other)
: __nested_accessor_(__other.nested_accessor())
{}

_LIBCUDACXX_HIDE_FROM_ABI constexpr reference access(data_handle_type __p, size_t __i) const noexcept
{
return conj_if_needed(__nested_element_type(__nested_accessor_.access(__p, __i)));
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr typename offset_policy::data_handle_type
offset(data_handle_type __p, size_t __i) const noexcept
{
return __nested_accessor_.offset(__p, __i);
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr const _NestedAccessor& nested_accessor() const noexcept
{
return __nested_accessor_;
}

private:
_NestedAccessor __nested_accessor_;
};

template <class _ElementType, class _Extents, class _Layout, class _Accessor>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
conjugated(mdspan<_ElementType, _Extents, _Layout, _Accessor> __a)
{
using __value_type = typename decltype(__a)::value_type;
// Current status of [linalg] only optimizes if _Accessor is conjugated_accessor<_Accessor> for some _Accessor.
// There's __a separate specialization for that case below.

// P3050 optimizes conjugated's accessor type for when we know that it can't be complex: arithmetic types,
// and types for which `conj` is not ADL-findable.
if constexpr (is_arithmetic_v<__value_type> || !__conj_if_needed::_HasConj<__value_type>)
{
return mdspan<_ElementType, _Extents, _Layout, _Accessor>(__a.data_handle(), __a.mapping(), __a.accessor());
}
else
{
using __return_element_type = typename conjugated_accessor<_Accessor>::element_type;
using __return_accessor_type = conjugated_accessor<_Accessor>;
return mdspan<__return_element_type, _Extents, _Layout, __return_accessor_type>{
__a.data_handle(), __a.mapping(), __return_accessor_type(__a.accessor())};
}
_CCCL_UNREACHABLE();
}

// Conjugation is self-annihilating
template <class _ElementType, class _Extents, class _Layout, class _NestedAccessor>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
conjugated(mdspan<_ElementType, _Extents, _Layout, conjugated_accessor<_NestedAccessor>> __a)
{
using __return_element_type = typename _NestedAccessor::element_type;
using __return_accessor_type = _NestedAccessor;
return mdspan<__return_element_type, _Extents, _Layout, __return_accessor_type>(
__a.data_handle(), __a.mapping(), __a.accessor().nested_accessor());
}

} // end namespace linalg

_LIBCUDACXX_END_NAMESPACE_STD

#endif // defined(__cccl_lib_mdspan) && _CCCL_STD_VER >= 2017
#endif // _LIBCUDACXX___LINALG_CONJUGATED_HPP
Loading
Loading