diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index d5b632ef5..276870ed3 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -33,6 +33,29 @@ BEGIN_ROCPRIM_NAMESPACE +template +struct LoadImpl { + ROCPRIM_HOST_DEVICE static T apply(ITType src) { + return *src; + } +}; + +template +struct LoadImpl { + ROCPRIM_HOST_DEVICE static bool apply(ITType src) { + static_assert(sizeof(bool) == sizeof(char), ""); + // Protect against invalid boolean values by loading as a byte + // first, then converting to bool. + return *reinterpret_cast(src); + } +}; + + +template +ROCPRIM_HOST_DEVICE T load(ITType src) { + return LoadImpl::apply(src); +} + /// \class transform_iterator /// \brief A random-access input (read-only) iterator adaptor for transforming dereferenced values. /// @@ -73,6 +96,8 @@ class transform_iterator /// The type of unary function used to transform input range. using unary_function = UnaryFunction; + using deref_type = typename std::iterator_traits::value_type; + #ifndef DOXYGEN_SHOULD_SKIP_THIS using self_type = transform_iterator; #endif @@ -125,7 +150,7 @@ class transform_iterator ROCPRIM_HOST_DEVICE inline value_type operator*() const { - return transform_(*iterator_); + return transform_(load(iterator_)); } ROCPRIM_HOST_DEVICE inline