Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic CUB dispatch for scan to support c.parallel #3398

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 91 additions & 36 deletions cub/cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#endif // no system header

#include <cub/agent/agent_scan.cuh>
#include <cub/detail/launcher/cuda_runtime.cuh>
#include <cub/device/dispatch/kernels/scan.cuh>
#include <cub/device/dispatch/tuning/tuning_scan.cuh>
#include <cub/grid/grid_queue.cuh>
Expand All @@ -62,6 +63,38 @@

CUB_NAMESPACE_BEGIN

template <typename MaxPolicyT,
typename InputIteratorT,
typename OutputIteratorT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT,
bool ForceInclusive>
struct DeviceScanKernelSource
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

important: we probably don't want this struct as part of our public API. Let's wrap it into detail namespace.

{
using ScanTileStateT = typename cub::ScanTileState<AccumT>;

CUB_DEFINE_KERNEL_GETTER(ScanInitKernel, DeviceScanInitKernel<ScanTileStateT>)

CUB_DEFINE_KERNEL_GETTER(
ScanKernel,
DeviceScanKernel<MaxPolicyT,
InputIteratorT,
OutputIteratorT,
ScanTileStateT,
ScanOpT,
InitValueT,
OffsetT,
AccumT,
ForceInclusive>)

CUB_RUNTIME_FUNCTION static constexpr std::size_t AccumSize()
{
return sizeof(AccumT);
}
};

/******************************************************************************
* Dispatch
******************************************************************************/
Expand Down Expand Up @@ -95,13 +128,23 @@ template <typename InputIteratorT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT = ::cuda::std::__accumulator_t<ScanOpT,
cub::detail::value_t<InputIteratorT>,
::cuda::std::_If<std::is_same<InitValueT, NullType>::value,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>>,
typename PolicyHub = detail::scan::policy_hub<AccumT, ScanOpT>,
bool ForceInclusive = false>
typename AccumT = ::cuda::std::__accumulator_t<ScanOpT,
cub::detail::value_t<InputIteratorT>,
::cuda::std::_If<std::is_same<InitValueT, NullType>::value,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>>,
typename PolicyHub = detail::scan::policy_hub<AccumT, ScanOpT>,
bool ForceInclusive = false,
typename KernelSource = DeviceScanKernelSource<typename PolicyHub::MaxPolicy,
InputIteratorT,
OutputIteratorT,
ScanOpT,
InitValueT,
OffsetT,
AccumT,
ForceInclusive>,
typename KernelLauncherFactory = detail::TripleChevronFactory>

struct DispatchScan
{
//---------------------------------------------------------------------
Expand Down Expand Up @@ -141,6 +184,10 @@ struct DispatchScan

int ptx_version;

KernelSource kernel_source;

KernelLauncherFactory launcher_factory;

/**
*
* @param[in] d_temp_storage
Expand Down Expand Up @@ -179,7 +226,9 @@ struct DispatchScan
ScanOpT scan_op,
InitValueT init_value,
cudaStream_t stream,
int ptx_version)
int ptx_version,
KernelSource kernel_source = {},
KernelLauncherFactory launcher_factory = {})
: d_temp_storage(d_temp_storage)
, temp_storage_bytes(temp_storage_bytes)
, d_in(d_in)
Expand All @@ -189,17 +238,20 @@ struct DispatchScan
, num_items(num_items)
, stream(stream)
, ptx_version(ptx_version)
, kernel_source(kernel_source)
, launcher_factory(launcher_factory)
{}

template <typename ActivePolicyT, typename InitKernel, typename ScanKernel>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel)
template <typename ActivePolicyT, typename InitKernelT, typename ScanKernelT>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t
Invoke(InitKernelT init_kernel, ScanKernelT scan_kernel, ActivePolicyT policy = {})
{
using Policy = typename ActivePolicyT::ScanPolicyT;
using ScanTileStateT = typename cub::ScanTileState<AccumT>;
using ScanTileStateT = typename KernelSource::ScanTileStateT;

// TODO(ashwin): Don't know how to handle this.
// `LOAD_LDG` makes in-place execution UB and doesn't lead to better
// performance.
static_assert(Policy::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG,
static_assert(policy.LoadModifier() != CacheLoadModifier::LOAD_LDG,
"The memory consistency model does not apply to texture "
"accesses");

Expand All @@ -215,7 +267,7 @@ struct DispatchScan
}

// Number of input tiles
int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD;
int tile_size = policy.Scan().BlockThreads() * policy.Scan().ItemsPerThread();
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));

// Specify temporary storage allocation requirements
Expand Down Expand Up @@ -265,8 +317,7 @@ struct DispatchScan
#endif // CUB_DETAIL_DEBUG_ENABLE_LOG

// Invoke init_kernel to initialize tile descriptors
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(init_grid_size, INIT_KERNEL_THREADS, 0, stream)
.doit(init_kernel, tile_state, num_tiles);
launcher_factory(init_grid_size, INIT_KERNEL_THREADS, 0, stream).doit(init_kernel, tile_state, num_tiles);

// Check for failure to launch
error = CubDebug(cudaPeekAtLastError());
Expand All @@ -286,12 +337,13 @@ struct DispatchScan
int scan_sm_occupancy;
error = CubDebug(MaxSmOccupancy(scan_sm_occupancy, // out
scan_kernel,
Policy::BLOCK_THREADS));
policy.Scan().BlockThreads()));
if (cudaSuccess != error)
{
break;
}

// TODO(ashwin): should this come from the launcher factory instead?
// Get max x-dimension of grid
int max_dim_x;
error = CubDebug(cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal));
Comment on lines +346 to 349
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: yes, all CUDA Runtime calls should be consolidated in launcher factory

Expand All @@ -310,14 +362,14 @@ struct DispatchScan
"per thread, %d SM occupancy\n",
start_tile,
scan_grid_size,
Policy::BLOCK_THREADS,
policy.Scan().BlockThreads(),
(long long) stream,
Policy::ITEMS_PER_THREAD,
policy.Scan().ItemsPerThread(),
scan_sm_occupancy);
#endif // CUB_DETAIL_DEBUG_ENABLE_LOG

// Invoke scan_kernel
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(scan_grid_size, Policy::BLOCK_THREADS, 0, stream)
launcher_factory(scan_grid_size, policy.Scan().BlockThreads(), 0, stream)
.doit(scan_kernel, d_in, d_out, tile_state, start_tile, scan_op, init_value, num_items);

// Check for failure to launch
Expand All @@ -340,21 +392,12 @@ struct DispatchScan
}

template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke()
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT active_policy = {})
{
using ScanTileStateT = typename cub::ScanTileState<AccumT>;
using ScanTileStateT = typename KernelSource::ScanTileStateT;
auto wrapped_policy = MakeScanPolicyWrapper(active_policy);
// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
DeviceScanInitKernel<ScanTileStateT>,
DeviceScanKernel<typename PolicyHub::MaxPolicy,
InputIteratorT,
OutputIteratorT,
ScanTileStateT,
ScanOpT,
InitValueT,
OffsetT,
AccumT,
ForceInclusive>);
return Invoke(kernel_source.ScanInitKernel(), kernel_source.ScanKernel(), wrapped_policy);
}

/**
Expand Down Expand Up @@ -388,6 +431,7 @@ struct DispatchScan
* Default is stream<sub>0</sub>.
*
*/
template <typename MaxPolicyT = typename PolicyHub::MaxPolicy>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -396,7 +440,9 @@ struct DispatchScan
ScanOpT scan_op,
InitValueT init_value,
OffsetT num_items,
cudaStream_t stream)
cudaStream_t stream,
KernelSource kernel_source = {},
MaxPolicyT max_policy = {})
{
cudaError_t error;
do
Expand All @@ -411,10 +457,19 @@ struct DispatchScan

// Create dispatch functor
DispatchScan dispatch(
d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, scan_op, init_value, stream, ptx_version);
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
scan_op,
init_value,
stream,
ptx_version,
kernel_source);

// Dispatch to chained policy
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
error = CubDebug(max_policy.Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
break;
Expand Down
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/kernels/scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ template <typename ChainedPolicyT,
typename OffsetT,
typename AccumT,
bool ForceInclusive>
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel(
InputIteratorT d_in,
OutputIteratorT d_out,
Expand All @@ -166,7 +166,7 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
OffsetT num_items)
{
using RealInitValueT = typename InitValueT::value_type;
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT;
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicy;

// Thread block type for scanning input tiles
using AgentScanT =
Expand Down
47 changes: 41 additions & 6 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/thread/thread_load.cuh>
#include <cub/util_device.cuh>
#include <cub/util_type.cuh>

Expand Down Expand Up @@ -229,6 +230,37 @@ struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_si
#endif
// clang-format on

template <typename PolicyT, typename = void, typename = void>
struct ScanPolicyWrapper : PolicyT
{
CUB_RUNTIME_FUNCTION ScanPolicyWrapper(PolicyT base)
: PolicyT(base)
{}
};

template <typename StaticPolicyT>
struct ScanPolicyWrapper<StaticPolicyT, ::cuda::std::void_t<decltype(StaticPolicyT::ScanPolicy::LOAD_MODIFIER)>>
: StaticPolicyT

{
CUB_RUNTIME_FUNCTION ScanPolicyWrapper(StaticPolicyT base)
: StaticPolicyT(base)
{}

CUB_DEFINE_SUB_POLICY_GETTER(Scan)

CUB_RUNTIME_FUNCTION constexpr CacheLoadModifier LoadModifier()
{
return StaticPolicyT::ScanPolicy::LOAD_MODIFIER;
}
};

template <typename PolicyT>
CUB_RUNTIME_FUNCTION ScanPolicyWrapper<PolicyT> MakeScanPolicyWrapper(PolicyT policy)
{
return ScanPolicyWrapper<PolicyT>{policy};
}

template <typename AccumT, typename ScanOpT>
struct policy_hub
{
Expand All @@ -242,19 +274,19 @@ struct policy_hub
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
{
// GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T
using ScanPolicyT =
using ScanPolicy =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The macro CUB_DEFINE_SUB_POLICY_GETTER makes some assumptions about the name of the policy so this rename is necessary.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sooo this is potentially problematic. The macro can be changed, what I'm mildly concerned about here is that this will break users who provide their own policies and spell them ScanPolicyT, and that stops working with this PR. cc @gevtushenko

We should probably just inline the macro above and make this work with the existing name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, code passing user-defined policy should still work. For instance, this change should brake our scan tuning:

, we just don't build it as part of CI

AgentScanPolicy<128, 12, AccumT, BLOCK_LOAD_DIRECT, LOAD_CA, BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, BLOCK_SCAN_RAKING>;
};
struct Policy520 : ChainedPolicy<520, Policy520, Policy350>
{
// Titan X: 32.47B items/s @ 48M 32-bit T
using ScanPolicyT =
using ScanPolicy =
AgentScanPolicy<128, 12, AccumT, BLOCK_LOAD_DIRECT, LOAD_CA, scan_transposed_store, BLOCK_SCAN_WARP_SCANS>;
};

struct DefaultPolicy
{
using ScanPolicyT =
using ScanPolicy =
AgentScanPolicy<128, 15, AccumT, scan_transposed_load, LOAD_DEFAULT, scan_transposed_store, BLOCK_SCAN_WARP_SCANS>;
};

Expand All @@ -276,11 +308,11 @@ struct policy_hub
MemBoundScaling<Tuning::threads, Tuning::items, AccumT>,
typename Tuning::delay_constructor>;
template <typename Tuning>
static auto select_agent_policy(long) -> typename DefaultPolicy::ScanPolicyT;
static auto select_agent_policy(long) -> typename DefaultPolicy::ScanPolicy;

struct Policy800 : ChainedPolicy<800, Policy800, Policy600>
{
using ScanPolicyT = decltype(select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0));
using ScanPolicy = decltype(select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0));
};

struct Policy860
Expand All @@ -290,7 +322,7 @@ struct policy_hub

struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using ScanPolicyT = decltype(select_agent_policy<sm90_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0));
using ScanPolicy = decltype(select_agent_policy<sm90_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0));
};

using MaxPolicy = Policy900;
Expand All @@ -302,4 +334,7 @@ struct policy_hub
template <typename AccumT, typename ScanOpT = ::cuda::std::plus<>>
using DeviceScanPolicy = detail::scan::policy_hub<AccumT, ScanOpT>;

using detail::scan::MakeScanPolicyWrapper;
using detail::scan::ScanPolicyWrapper;

CUB_NAMESPACE_END
Loading