diff --git a/cub/cub/device/dispatch/tuning/tuning_scan.cuh b/cub/cub/device/dispatch/tuning/tuning_scan.cuh index ac5dbfc5868..2163c4b7431 100644 --- a/cub/cub/device/dispatch/tuning/tuning_scan.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_scan.cuh @@ -108,12 +108,16 @@ constexpr accum_size classify_accum_size() : accum_size::unknown; } -template +template struct tuning { static constexpr int threads = Threads; static constexpr int items = Items; using delay_constructor = fixed_delay_constructor_t; + static constexpr BlockLoadAlgorithm load_algorithm = + (sizeof(AccumT) > 128) ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED : BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr BlockStoreAlgorithm store_algorithm = + (sizeof(AccumT) > 128) ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED : BLOCK_STORE_WARP_TRANSPOSE; }; template struct sm90_tuning : tuning<192, 22, 168, 1140> {}; -template struct sm90_tuning : tuning<512, 12, 376, 1125> {}; -template struct sm90_tuning : tuning<128, 24, 648, 1245> {}; -template struct sm90_tuning : tuning<224, 24, 632, 1290> {}; +template struct sm90_tuning : tuning {}; +template struct sm90_tuning : tuning {}; +template struct sm90_tuning : tuning {}; +template struct sm90_tuning : tuning {}; -template <> struct sm90_tuning : tuning<128, 24, 688, 1140> {}; -template <> struct sm90_tuning : tuning<224, 24, 576, 1215> {}; +template <> struct sm90_tuning : tuning {}; +template <> struct sm90_tuning : tuning {}; #if CUB_IS_INT128_ENABLED -template <> struct sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<576, 21, 860, 630> {}; +template <> struct sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<__int128_t, 576, 21, 860, 630> {}; template <> struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>