-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dynamic CUB dispatch for scan to support c.parallel #3398
base: main
Are you sure you want to change the base?
Changes from all commits
bb5744a
34d647b
3ccf70e
b4fe9e9
8a9de0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,7 @@ | |
#endif // no system header | ||
|
||
#include <cub/agent/agent_scan.cuh> | ||
#include <cub/detail/launcher/cuda_runtime.cuh> | ||
#include <cub/device/dispatch/kernels/scan.cuh> | ||
#include <cub/device/dispatch/tuning/tuning_scan.cuh> | ||
#include <cub/grid/grid_queue.cuh> | ||
|
@@ -62,6 +63,38 @@ | |
|
||
CUB_NAMESPACE_BEGIN | ||
|
||
// Kernel source for device-wide scan: for one fixed set of template arguments it
// names the tile-state type, the two kernels (init and scan) and the accumulator size.
// DispatchScan takes this type as its `KernelSource` template parameter and obtains
// the kernels through `kernel_source.ScanInitKernel()` / `kernel_source.ScanKernel()`.
template <typename MaxPolicyT, | ||
typename InputIteratorT, | ||
typename OutputIteratorT, | ||
typename ScanOpT, | ||
typename InitValueT, | ||
typename OffsetT, | ||
typename AccumT, | ||
bool ForceInclusive> | ||
struct DeviceScanKernelSource | ||
{ | ||
// Tile-state type, parameterized on the accumulator type; passed to both kernels below.
using ScanTileStateT = typename cub::ScanTileState<AccumT>; | ||
|
||
// Getter named ScanInitKernel for DeviceScanInitKernel<ScanTileStateT>.
// NOTE(review): CUB_DEFINE_KERNEL_GETTER is defined outside this view; presumably it
// expands to a member function returning the named kernel - confirm against the macro.
CUB_DEFINE_KERNEL_GETTER(ScanInitKernel, DeviceScanInitKernel<ScanTileStateT>) | ||
|
||
// Getter named ScanKernel for the DeviceScanKernel instantiation built from this
// struct's template arguments and ScanTileStateT.
CUB_DEFINE_KERNEL_GETTER( | ||
ScanKernel, | ||
DeviceScanKernel<MaxPolicyT, | ||
InputIteratorT, | ||
OutputIteratorT, | ||
ScanTileStateT, | ||
ScanOpT, | ||
InitValueT, | ||
OffsetT, | ||
AccumT, | ||
ForceInclusive>) | ||
|
||
// Returns sizeof(AccumT), i.e. the size in bytes of the accumulator type.
CUB_RUNTIME_FUNCTION static constexpr std::size_t AccumSize() | ||
{ | ||
return sizeof(AccumT); | ||
} | ||
}; | ||
|
||
/****************************************************************************** | ||
* Dispatch | ||
******************************************************************************/ | ||
|
@@ -95,13 +128,23 @@ template <typename InputIteratorT, | |
typename ScanOpT, | ||
typename InitValueT, | ||
typename OffsetT, | ||
typename AccumT = ::cuda::std::__accumulator_t<ScanOpT, | ||
cub::detail::value_t<InputIteratorT>, | ||
::cuda::std::_If<std::is_same<InitValueT, NullType>::value, | ||
cub::detail::value_t<InputIteratorT>, | ||
typename InitValueT::value_type>>, | ||
typename PolicyHub = detail::scan::policy_hub<AccumT, ScanOpT>, | ||
bool ForceInclusive = false> | ||
typename AccumT = ::cuda::std::__accumulator_t<ScanOpT, | ||
cub::detail::value_t<InputIteratorT>, | ||
::cuda::std::_If<std::is_same<InitValueT, NullType>::value, | ||
cub::detail::value_t<InputIteratorT>, | ||
typename InitValueT::value_type>>, | ||
typename PolicyHub = detail::scan::policy_hub<AccumT, ScanOpT>, | ||
bool ForceInclusive = false, | ||
typename KernelSource = DeviceScanKernelSource<typename PolicyHub::MaxPolicy, | ||
InputIteratorT, | ||
OutputIteratorT, | ||
ScanOpT, | ||
InitValueT, | ||
OffsetT, | ||
AccumT, | ||
ForceInclusive>, | ||
typename KernelLauncherFactory = detail::TripleChevronFactory> | ||
|
||
struct DispatchScan | ||
{ | ||
//--------------------------------------------------------------------- | ||
|
@@ -141,6 +184,10 @@ struct DispatchScan | |
|
||
int ptx_version; | ||
|
||
KernelSource kernel_source; | ||
|
||
KernelLauncherFactory launcher_factory; | ||
|
||
/** | ||
* | ||
* @param[in] d_temp_storage | ||
|
@@ -179,7 +226,9 @@ struct DispatchScan | |
ScanOpT scan_op, | ||
InitValueT init_value, | ||
cudaStream_t stream, | ||
int ptx_version) | ||
int ptx_version, | ||
KernelSource kernel_source = {}, | ||
KernelLauncherFactory launcher_factory = {}) | ||
: d_temp_storage(d_temp_storage) | ||
, temp_storage_bytes(temp_storage_bytes) | ||
, d_in(d_in) | ||
|
@@ -189,17 +238,20 @@ struct DispatchScan | |
, num_items(num_items) | ||
, stream(stream) | ||
, ptx_version(ptx_version) | ||
, kernel_source(kernel_source) | ||
, launcher_factory(launcher_factory) | ||
{} | ||
|
||
template <typename ActivePolicyT, typename InitKernel, typename ScanKernel> | ||
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel) | ||
template <typename ActivePolicyT, typename InitKernelT, typename ScanKernelT> | ||
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t | ||
Invoke(InitKernelT init_kernel, ScanKernelT scan_kernel, ActivePolicyT policy = {}) | ||
{ | ||
using Policy = typename ActivePolicyT::ScanPolicyT; | ||
using ScanTileStateT = typename cub::ScanTileState<AccumT>; | ||
using ScanTileStateT = typename KernelSource::ScanTileStateT; | ||
|
||
// TODO(ashwin): Don't know how to handle this. | ||
// `LOAD_LDG` makes in-place execution UB and doesn't lead to better | ||
// performance. | ||
static_assert(Policy::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG, | ||
static_assert(policy.LoadModifier() != CacheLoadModifier::LOAD_LDG, | ||
"The memory consistency model does not apply to texture " | ||
"accesses"); | ||
|
||
|
@@ -215,7 +267,7 @@ struct DispatchScan | |
} | ||
|
||
// Number of input tiles | ||
int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; | ||
int tile_size = policy.Scan().BlockThreads() * policy.Scan().ItemsPerThread(); | ||
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size)); | ||
|
||
// Specify temporary storage allocation requirements | ||
|
@@ -265,8 +317,7 @@ struct DispatchScan | |
#endif // CUB_DETAIL_DEBUG_ENABLE_LOG | ||
|
||
// Invoke init_kernel to initialize tile descriptors | ||
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(init_grid_size, INIT_KERNEL_THREADS, 0, stream) | ||
.doit(init_kernel, tile_state, num_tiles); | ||
launcher_factory(init_grid_size, INIT_KERNEL_THREADS, 0, stream).doit(init_kernel, tile_state, num_tiles); | ||
|
||
// Check for failure to launch | ||
error = CubDebug(cudaPeekAtLastError()); | ||
|
@@ -286,12 +337,13 @@ struct DispatchScan | |
int scan_sm_occupancy; | ||
error = CubDebug(MaxSmOccupancy(scan_sm_occupancy, // out | ||
scan_kernel, | ||
Policy::BLOCK_THREADS)); | ||
policy.Scan().BlockThreads())); | ||
if (cudaSuccess != error) | ||
{ | ||
break; | ||
} | ||
|
||
// TODO(ashwin): should this come from the launcher factory instead? | ||
// Get max x-dimension of grid | ||
int max_dim_x; | ||
error = CubDebug(cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal)); | ||
Comment on lines
+346
to
349
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion: yes, all CUDA Runtime calls should be consolidated in the launcher factory. |
||
|
@@ -310,14 +362,14 @@ struct DispatchScan | |
"per thread, %d SM occupancy\n", | ||
start_tile, | ||
scan_grid_size, | ||
Policy::BLOCK_THREADS, | ||
policy.Scan().BlockThreads(), | ||
(long long) stream, | ||
Policy::ITEMS_PER_THREAD, | ||
policy.Scan().ItemsPerThread(), | ||
scan_sm_occupancy); | ||
#endif // CUB_DETAIL_DEBUG_ENABLE_LOG | ||
|
||
// Invoke scan_kernel | ||
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(scan_grid_size, Policy::BLOCK_THREADS, 0, stream) | ||
launcher_factory(scan_grid_size, policy.Scan().BlockThreads(), 0, stream) | ||
.doit(scan_kernel, d_in, d_out, tile_state, start_tile, scan_op, init_value, num_items); | ||
|
||
// Check for failure to launch | ||
|
@@ -340,21 +392,12 @@ struct DispatchScan | |
} | ||
|
||
template <typename ActivePolicyT> | ||
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke() | ||
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT active_policy = {}) | ||
{ | ||
using ScanTileStateT = typename cub::ScanTileState<AccumT>; | ||
using ScanTileStateT = typename KernelSource::ScanTileStateT; | ||
auto wrapped_policy = MakeScanPolicyWrapper(active_policy); | ||
// Ensure kernels are instantiated. | ||
return Invoke<ActivePolicyT>( | ||
DeviceScanInitKernel<ScanTileStateT>, | ||
DeviceScanKernel<typename PolicyHub::MaxPolicy, | ||
InputIteratorT, | ||
OutputIteratorT, | ||
ScanTileStateT, | ||
ScanOpT, | ||
InitValueT, | ||
OffsetT, | ||
AccumT, | ||
ForceInclusive>); | ||
return Invoke(kernel_source.ScanInitKernel(), kernel_source.ScanKernel(), wrapped_policy); | ||
} | ||
|
||
/** | ||
|
@@ -388,6 +431,7 @@ struct DispatchScan | |
* Default is stream<sub>0</sub>. | ||
* | ||
*/ | ||
template <typename MaxPolicyT = typename PolicyHub::MaxPolicy> | ||
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch( | ||
void* d_temp_storage, | ||
size_t& temp_storage_bytes, | ||
|
@@ -396,7 +440,9 @@ struct DispatchScan | |
ScanOpT scan_op, | ||
InitValueT init_value, | ||
OffsetT num_items, | ||
cudaStream_t stream) | ||
cudaStream_t stream, | ||
KernelSource kernel_source = {}, | ||
MaxPolicyT max_policy = {}) | ||
{ | ||
cudaError_t error; | ||
do | ||
|
@@ -411,10 +457,19 @@ struct DispatchScan | |
|
||
// Create dispatch functor | ||
DispatchScan dispatch( | ||
d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, scan_op, init_value, stream, ptx_version); | ||
d_temp_storage, | ||
temp_storage_bytes, | ||
d_in, | ||
d_out, | ||
num_items, | ||
scan_op, | ||
init_value, | ||
stream, | ||
ptx_version, | ||
kernel_source); | ||
|
||
// Dispatch to chained policy | ||
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch)); | ||
error = CubDebug(max_policy.Invoke(ptx_version, dispatch)); | ||
if (cudaSuccess != error) | ||
{ | ||
break; | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -42,6 +42,7 @@ | |||
#include <cub/block/block_load.cuh> | ||||
#include <cub/block/block_scan.cuh> | ||||
#include <cub/block/block_store.cuh> | ||||
#include <cub/thread/thread_load.cuh> | ||||
#include <cub/util_device.cuh> | ||||
#include <cub/util_type.cuh> | ||||
|
||||
|
@@ -229,6 +230,37 @@ struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_si | |||
#endif | ||||
// clang-format on | ||||
|
||||
// Primary (fallback) policy wrapper: derives from PolicyT and adds no accessors of
// its own. The two defaulted type parameters are slots for partial specializations;
// the specialization in this file uses the first one for void_t-based detection.
template <typename PolicyT, typename = void, typename = void> | ||||
struct ScanPolicyWrapper : PolicyT | ||||
{ | ||||
// Copy-constructs the PolicyT base sub-object from `base`.
CUB_RUNTIME_FUNCTION ScanPolicyWrapper(PolicyT base) | ||||
: PolicyT(base) | ||||
{} | ||||
}; | ||||
|
||||
// Specialization chosen when `StaticPolicyT::ScanPolicy::LOAD_MODIFIER` is a valid
// expression, i.e. for policies that expose a nested `ScanPolicy` with a
// LOAD_MODIFIER member. It adds accessors on top of the wrapped policy.
// NOTE(review): detection keys on the nested name `ScanPolicy`; a policy that spells
// it differently (e.g. `ScanPolicyT`) falls back to the primary template, which
// provides neither Scan() nor LoadModifier().
template <typename StaticPolicyT> | ||||
struct ScanPolicyWrapper<StaticPolicyT, ::cuda::std::void_t<decltype(StaticPolicyT::ScanPolicy::LOAD_MODIFIER)>> | ||||
: StaticPolicyT | ||||
|
||||
{ | ||||
// Copy-constructs the StaticPolicyT base sub-object from `base`.
CUB_RUNTIME_FUNCTION ScanPolicyWrapper(StaticPolicyT base) | ||||
: StaticPolicyT(base) | ||||
{} | ||||
|
||||
// NOTE(review): CUB_DEFINE_SUB_POLICY_GETTER is defined outside this view. The
// dispatch code calls `policy.Scan().BlockThreads()` / `.ItemsPerThread()`, so this
// presumably generates a Scan() accessor for the `ScanPolicy` sub-policy - confirm.
CUB_DEFINE_SUB_POLICY_GETTER(Scan) | ||||
|
||||
// Returns the compile-time LOAD_MODIFIER of the wrapped policy's ScanPolicy.
CUB_RUNTIME_FUNCTION constexpr CacheLoadModifier LoadModifier() | ||||
{ | ||||
return StaticPolicyT::ScanPolicy::LOAD_MODIFIER; | ||||
} | ||||
}; | ||||
|
||||
// Factory that deduces PolicyT from its argument and returns
// ScanPolicyWrapper<PolicyT> brace-initialized from `policy`.
template <typename PolicyT> | ||||
CUB_RUNTIME_FUNCTION ScanPolicyWrapper<PolicyT> MakeScanPolicyWrapper(PolicyT policy) | ||||
{ | ||||
return ScanPolicyWrapper<PolicyT>{policy}; | ||||
} | ||||
|
||||
template <typename AccumT, typename ScanOpT> | ||||
struct policy_hub | ||||
{ | ||||
|
@@ -242,19 +274,19 @@ struct policy_hub | |||
struct Policy350 : ChainedPolicy<350, Policy350, Policy350> | ||||
{ | ||||
// GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T | ||||
using ScanPolicyT = | ||||
using ScanPolicy = | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The macro There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sooo this is potentially problematic. The macro can be changed; what I'm mildly concerned about here is that this will break users who provide their own policies and spell them `ScanPolicyT`. We should probably just inline the macro above and make this work with the existing name. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed, code passing a user-defined policy should still work. For instance, this change should break our scan tuning:
|
||||
AgentScanPolicy<128, 12, AccumT, BLOCK_LOAD_DIRECT, LOAD_CA, BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, BLOCK_SCAN_RAKING>; | ||||
}; | ||||
struct Policy520 : ChainedPolicy<520, Policy520, Policy350> | ||||
{ | ||||
// Titan X: 32.47B items/s @ 48M 32-bit T | ||||
using ScanPolicyT = | ||||
using ScanPolicy = | ||||
AgentScanPolicy<128, 12, AccumT, BLOCK_LOAD_DIRECT, LOAD_CA, scan_transposed_store, BLOCK_SCAN_WARP_SCANS>; | ||||
}; | ||||
|
||||
struct DefaultPolicy | ||||
{ | ||||
using ScanPolicyT = | ||||
using ScanPolicy = | ||||
AgentScanPolicy<128, 15, AccumT, scan_transposed_load, LOAD_DEFAULT, scan_transposed_store, BLOCK_SCAN_WARP_SCANS>; | ||||
}; | ||||
|
||||
|
@@ -276,11 +308,11 @@ struct policy_hub | |||
MemBoundScaling<Tuning::threads, Tuning::items, AccumT>, | ||||
typename Tuning::delay_constructor>; | ||||
template <typename Tuning> | ||||
static auto select_agent_policy(long) -> typename DefaultPolicy::ScanPolicyT; | ||||
static auto select_agent_policy(long) -> typename DefaultPolicy::ScanPolicy; | ||||
|
||||
struct Policy800 : ChainedPolicy<800, Policy800, Policy600> | ||||
{ | ||||
using ScanPolicyT = decltype(select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0)); | ||||
using ScanPolicy = decltype(select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0)); | ||||
}; | ||||
|
||||
struct Policy860 | ||||
|
@@ -290,7 +322,7 @@ struct policy_hub | |||
|
||||
struct Policy900 : ChainedPolicy<900, Policy900, Policy860> | ||||
{ | ||||
using ScanPolicyT = decltype(select_agent_policy<sm90_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0)); | ||||
using ScanPolicy = decltype(select_agent_policy<sm90_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0)); | ||||
}; | ||||
|
||||
using MaxPolicy = Policy900; | ||||
|
@@ -302,4 +334,7 @@ struct policy_hub | |||
template <typename AccumT, typename ScanOpT = ::cuda::std::plus<>> | ||||
using DeviceScanPolicy = detail::scan::policy_hub<AccumT, ScanOpT>; | ||||
|
||||
using detail::scan::MakeScanPolicyWrapper; | ||||
using detail::scan::ScanPolicyWrapper; | ||||
|
||||
CUB_NAMESPACE_END |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
important: we probably don't want this struct as part of our public API. Let's wrap it into the `detail` namespace.