Skip to content

Commit

Permalink
Add weak_ptr<void> callback_lifetime to SubscriptionOptions
Browse files Browse the repository at this point in the history
Avoid potential use after free usage of a registered
subscription callback function by allowing user to
specify a weak_ptr to be checked for
expiry before the associated subscription callback is called.

If user does not specify callback_lifetime,
the mechanism falls back to a tracking the lifetime
of a user specified callback_group, failing that it
tracks the lifetime of the nodes default_callback_group.

Signed-off-by: Mike Wake <[email protected]>
  • Loading branch information
ewak committed Jan 12, 2025
1 parent 9cabd69 commit d9a56d2
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,16 @@ class SubscriptionIntraProcess
rclcpp::Context::SharedPtr context,
const std::string & topic_name,
const rclcpp::QoS & qos_profile,
rclcpp::IntraProcessBufferType buffer_type)
rclcpp::IntraProcessBufferType buffer_type,
std::weak_ptr<void> callback_lifetime)
: SubscriptionIntraProcessBuffer<SubscribedType, SubscribedTypeAlloc,
SubscribedTypeDeleter, ROSMessageType>(
std::make_shared<SubscribedTypeAlloc>(*allocator),
context,
topic_name,
qos_profile,
buffer_type),
callback_lifetime_(callback_lifetime),
any_callback_(callback)
{
TRACETOOLS_TRACEPOINT(
Expand Down Expand Up @@ -166,6 +168,10 @@ class SubscriptionIntraProcess
typename std::enable_if<!std::is_same<T, rcl_serialized_message_t>::value, void>::type
execute_impl(const std::shared_ptr<void> & data)
{
if (callback_lifetime_.expired()) {
return;
}

if (nullptr == data) {
return;
}
Expand All @@ -187,6 +193,7 @@ class SubscriptionIntraProcess
shared_ptr.reset();
}

std::weak_ptr<void> callback_lifetime_;
AnySubscriptionCallback<MessageT, Alloc> any_callback_;
};

Expand Down
69 changes: 52 additions & 17 deletions rclcpp/include/rclcpp/subscription.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ class Subscription : public SubscriptionBase
"' is not allowed with 0 depth qos policy");
}

// Use std::weak_ptr owner_before trick to determine if user
// has assigned a subscription options_.callback_lifetime weak_ptr.
// https://stackoverflow.com/a/45507610
std::weak_ptr<void> empty;
if (!options_.callback_lifetime.owner_before(empty) &&
!empty.owner_before(options_.callback_lifetime)) {
// options_.callback_lifetime was not user assigned,
// So use options_.callback_group if user assigned,
// falling back to node's default_callback_group
std::shared_ptr<void> vsp = options_.callback_group != nullptr ?
options_.callback_group :
node_base->get_default_callback_group();
std::weak_ptr<void> vwp = vsp;
options_.callback_lifetime = vwp;
}

if (options_.callback_lifetime.expired())
{
throw std::invalid_argument(
"callback_lifetime weak_ptr for topic '" + topic_name +
"' has already expired");
}

using SubscriptionIntraProcessT = rclcpp::experimental::SubscriptionIntraProcess<
MessageT,
SubscribedType,
Expand All @@ -172,7 +195,8 @@ class Subscription : public SubscriptionBase
context,
this->get_topic_name(), // important to get like this, as it has the fully-qualified name
qos_profile,
resolve_intra_process_buffer_type(options_.intra_process_buffer_type, callback));
resolve_intra_process_buffer_type(options_.intra_process_buffer_type, callback),
options_.callback_lifetime);
TRACETOOLS_TRACEPOINT(
rclcpp_subscription_init,
static_cast<const void *>(get_subscription_handle().get()),
Expand Down Expand Up @@ -300,12 +324,15 @@ class Subscription : public SubscriptionBase
now = std::chrono::system_clock::now();
}

any_callback_.dispatch(typed_message, message_info);
if (!options_.callback_lifetime.expired())
{
any_callback_.dispatch(typed_message, message_info);

if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
}
}
}

Expand All @@ -321,12 +348,15 @@ class Subscription : public SubscriptionBase
now = std::chrono::system_clock::now();
}

any_callback_.dispatch(serialized_message, message_info);
if (!options_.callback_lifetime.expired())
{
any_callback_.dispatch(serialized_message, message_info);

if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
}
}
}

Expand All @@ -353,12 +383,15 @@ class Subscription : public SubscriptionBase
now = std::chrono::system_clock::now();
}

any_callback_.dispatch(sptr, message_info);
if (!options_.callback_lifetime.expired())
{
any_callback_.dispatch(sptr, message_info);

if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
if (subscription_topic_statistics_) {
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
}
}
}

Expand Down Expand Up @@ -449,7 +482,9 @@ class Subscription : public SubscriptionBase
* It is important to save a copy of this so that the rmw payload which it
* may contain is kept alive for the duration of the subscription.
*/
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> options_;
// NOTE: Had to drop const in order to set default options_.callback_lifetime
// if not set in user code.
rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> options_;
typename message_memory_strategy::MessageMemoryStrategy<ROSMessageType, AllocatorT>::SharedPtr
message_memory_strategy_;

Expand Down
2 changes: 2 additions & 0 deletions rclcpp/include/rclcpp/subscription_options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ struct SubscriptionOptionsBase
QosOverridingOptions qos_overriding_options;

ContentFilterOptions content_filter_options;

std::weak_ptr<void> callback_lifetime;
};

/// Structure containing optional configuration for Subscriptions.
Expand Down

0 comments on commit d9a56d2

Please sign in to comment.