From dc69e5215c1d4d1842c21db8abf395d5cf0c0f6c Mon Sep 17 00:00:00 2001 From: Max Schmeller <6088931+mojomex@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:51:46 +0900 Subject: [PATCH] feat(nebula_hw_interfaces): better UDP socket (#231) * feat(udp): a new UDP socket implementation Signed-off-by: Max SCHMELLER * chore(udp): more error handling and doc comments Signed-off-by: Max SCHMELLER * chore(udp): always enable socket reuse Signed-off-by: Max SCHMELLER * chore(udp): clean up C-style code Signed-off-by: Max SCHMELLER * chore(udp): remove unnecessary double parens Signed-off-by: Max SCHMELLER * feat(udp): use poll to prevent blocking when there is no received data Signed-off-by: Max SCHMELLER * chore(udp): differentiate between socket and usage-related errors Signed-off-by: Max SCHMELLER * feat(udp): allow setting receive buffer size Signed-off-by: Max SCHMELLER * chore(udp): use uint8_t because std::byte is annoying to refactor into everywhere Signed-off-by: Max SCHMELLER * fix(udp): update state correctly when `bind()` is called Signed-off-by: Max SCHMELLER * feat(udp): monitor socket packet drops Signed-off-by: Max SCHMELLER * feat(udp): add explicit unsubscribe function to facilitate clean shutdown Signed-off-by: Max SCHMELLER * chore(expected): add stdexcept include Signed-off-by: Max SCHMELLER * feat(udp): report when messages have been truncated Signed-off-by: Max SCHMELLER * chore(udp): relax some usage requirements Signed-off-by: Max SCHMELLER * test(udp): add most of the unit tests for udp socket Signed-off-by: Max SCHMELLER * ci(pre-commit): autofix * chore(cspell): add OVFL to dictionary Signed-off-by: Max SCHMELLER * fix(udp): return correctly truncated buffer when oversized packet is received Signed-off-by: Max SCHMELLER * chore(udp): uniform initialization for buffer_size_ Signed-off-by: Max SCHMELLER * chore(udp): disallow re-initializing socket Signed-off-by: Max SCHMELLER * chore(udp): make error value checking consistent (== -1) Signed-off-by: Max SCHMELLER * chore(udp): add explanatory comment on handling of 0-length datagrams Signed-off-by: Max SCHMELLER * chore(udp): disallow binding to broadcast IP Signed-off-by: Max SCHMELLER * chore(udp): bind to host IP instead of INADDR_ANY in non-multicast case Signed-off-by: Max SCHMELLER * feat(udp): make polling interval configurable Signed-off-by: Max SCHMELLER * chore(udp): rename ReceiveMetadata to RxMetadata Signed-off-by: Max SCHMELLER * chore(udp): parse IP addresses with error checking Signed-off-by: Max SCHMELLER * test(udp): update and fix tests after changes of recent commits Signed-off-by: Max SCHMELLER * feat(expected): add shorthand `value_or_throw()` function Signed-off-by: Max SCHMELLER * feat(udp): refactor to typestate-based builder pattern to make misuse a compiler error Signed-off-by: Max SCHMELLER * ci(pre-commit): autofix * chore(udp): replace `-1` with `uninitialized` in `SockFd` for clarity Signed-off-by: Max SCHMELLER * chore(udp): mark SockFd constructor explicit Signed-off-by: Max SCHMELLER * fix(udp): enforce that at most one multicast group is joined Signed-off-by: Max SCHMELLER * chore(udp): reorder UdpSocket class to reduce number of needed visibility modifiers Signed-off-by: Max SCHMELLER * fix(udp): make UDP socket object move-constructible to enable things like optional.emplace() etc. Signed-off-by: Max SCHMELLER --------- Signed-off-by: Max SCHMELLER Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .cspell.json | 1 + .../include/nebula_common/util/expected.hpp | 8 + nebula_hw_interfaces/CMakeLists.txt | 21 +- .../connections/udp.hpp | 458 ++++++++++++++++++ nebula_hw_interfaces/test/common/test_udp.cpp | 210 ++++++++ .../test/common/test_udp/utils.hpp | 77 +++ 6 files changed, 769 insertions(+), 6 deletions(-) create mode 100644 nebula_hw_interfaces/include/nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp create mode 100644 nebula_hw_interfaces/test/common/test_udp.cpp create mode 100644 nebula_hw_interfaces/test/common/test_udp/utils.hpp diff --git a/.cspell.json b/.cspell.json index 4c14cb9f3..f305225f6 100644 --- a/.cspell.json +++ b/.cspell.json @@ -32,6 +32,7 @@ "nproc", "nsec", "ntoa", + "OVFL", "pandar", "PANDAR", "PANDARAT", diff --git a/nebula_common/include/nebula_common/util/expected.hpp b/nebula_common/include/nebula_common/util/expected.hpp index 1d4333443..cd6061716 100644 --- a/nebula_common/include/nebula_common/util/expected.hpp +++ b/nebula_common/include/nebula_common/util/expected.hpp @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -78,6 +79,13 @@ struct expected throw std::runtime_error(error_msg); } + /// @brief If the instance has a value, return the value, else throw the stored error instance. + T value_or_throw() + { + if (has_value()) return value(); + throw error(); + } + /// @brief Retrieve the error, or throw `bad_expected_access` if a value is contained. /// @return The error of type `E` E error() diff --git a/nebula_hw_interfaces/CMakeLists.txt b/nebula_hw_interfaces/CMakeLists.txt index 2577ea9a5..f5dc6569b 100644 --- a/nebula_hw_interfaces/CMakeLists.txt +++ b/nebula_hw_interfaces/CMakeLists.txt @@ -2,13 +2,13 @@ cmake_minimum_required(VERSION 3.14) project(nebula_hw_interfaces) # Default to C++17 -if (NOT CMAKE_CXX_STANDARD) +if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) -endif () +endif() -if (CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") add_compile_options(-Wall -Wextra -Wpedantic -Wunused-function) -endif () +endif() find_package(ament_cmake_auto REQUIRED) find_package(boost_tcp_driver) @@ -53,7 +53,6 @@ target_link_libraries(nebula_hw_interfaces_velodyne PUBLIC ${boost_tcp_driver_LIBRARIES} ${boost_udp_driver_LIBRARIES} ${velodyne_msgs_TARGETS} - ) target_include_directories(nebula_hw_interfaces_velodyne PUBLIC ${boost_udp_driver_INCLUDE_DIRS} @@ -68,7 +67,6 @@ target_link_libraries(nebula_hw_interfaces_robosense PUBLIC ${boost_tcp_driver_LIBRARIES} ${boost_udp_driver_LIBRARIES} ${robosense_msgs_TARGETS} - ) target_include_directories(nebula_hw_interfaces_robosense PUBLIC ${boost_udp_driver_INCLUDE_DIRS} @@ -100,6 +98,17 @@ install(DIRECTORY include/ DESTINATION include/${PROJECT_NAME}) if(BUILD_TESTING) find_package(ament_lint_auto REQUIRED) ament_lint_auto_find_test_dependencies() + + find_package(ament_cmake_gtest REQUIRED) + + ament_add_gtest(test_udp + test/common/test_udp.cpp + ) + + target_include_directories(test_udp PUBLIC + ${nebula_common_INCLUDE_DIRS} + include + test) endif() ament_export_include_directories("include/${PROJECT_NAME}") diff --git a/nebula_hw_interfaces/include/nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp b/nebula_hw_interfaces/include/nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp new file mode 100644 index 000000000..77f91eccc --- /dev/null +++ b/nebula_hw_interfaces/include/nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp @@ -0,0 +1,458 @@ +// Copyright 2024 TIER IV, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifndef _GNU_SOURCE +// See `man strerror_r` +#define _GNU_SOURCE +#endif + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nebula::drivers::connections +{ + +class SocketError : public std::exception +{ + static constexpr size_t gnu_max_strerror_length = 1024; + +public: + explicit SocketError(int err_no) + { + std::array msg_buf; + std::string_view msg = strerror_r(err_no, msg_buf.data(), msg_buf.size()); + what_ = std::string{msg}; + } + + explicit SocketError(const std::string_view & msg) : what_(msg) {} + + const char * what() const noexcept override { return what_.c_str(); } + +private: + std::string what_; +}; + +class UsageError : public std::runtime_error +{ +public: + explicit UsageError(const std::string & msg) : std::runtime_error(msg) {} +}; + +class UdpSocket +{ + struct Endpoint + { + in_addr ip; + uint16_t port; + }; + + class SockFd + { + static const int uninitialized = -1; + int sock_fd_; + + public: + SockFd() : sock_fd_{uninitialized} {} + explicit SockFd(int sock_fd) : sock_fd_{sock_fd} {} + SockFd(SockFd && other) noexcept : sock_fd_{other.sock_fd_} { other.sock_fd_ = uninitialized; } + + SockFd(const SockFd &) = delete; + SockFd & operator=(const SockFd &) = delete; + SockFd & operator=(SockFd && other) + { + std::swap(sock_fd_, other.sock_fd_); + return *this; + }; + + ~SockFd() + { + if (sock_fd_ == uninitialized) return; + close(sock_fd_); + } + + [[nodiscard]] int get() const { return sock_fd_; } + + template + [[nodiscard]] util::expected setsockopt( + int level, int optname, const T & optval) + { + int result = ::setsockopt(sock_fd_, level, optname, &optval, sizeof(T)); + if (result == -1) return SocketError(errno); + return std::monostate{}; + } + }; + + struct SocketConfig + { + int32_t polling_interval_ms{10}; + + size_t buffer_size{1500}; + Endpoint host; + std::optional multicast_ip; + std::optional sender; + }; + + struct MsgBuffers + { + msghdr msg{}; + iovec iov{}; + std::array control; + sockaddr_in sender_addr; + }; + + class DropMonitor + { + uint32_t last_drop_counter_{0}; + + public: + uint32_t get_drops_since_last_receive(uint32_t current_drop_counter) + { + uint32_t last = last_drop_counter_; + last_drop_counter_ = current_drop_counter; + + bool counter_did_wrap = current_drop_counter < last; + if (counter_did_wrap) { + return (UINT32_MAX - last) + current_drop_counter; + } + + return current_drop_counter - last; + } + }; + + UdpSocket(SockFd sock_fd, SocketConfig config) + : sock_fd_(std::move(sock_fd)), poll_fd_{sock_fd_.get(), POLLIN, 0}, config_{std::move(config)} + { + } + +public: + class Builder + { + public: + /** + * @brief Build a UDP socket with timestamp measuring enabled. The minimal way to start + * receiving on the socket is `UdpSocket::Builder(...).bind().subscribe(...);`. + * + * @param host_ip The address to bind to. + * @param host_port The port to bind to. + */ + Builder(const std::string & host_ip, uint16_t host_port) + { + in_addr host_in_addr = parse_ip_or_throw(host_ip); + if (host_in_addr.s_addr == INADDR_BROADCAST) + throw UsageError("Do not bind to broadcast IP. Bind to 0.0.0.0 or a specific IP instead."); + + config_.host = {host_in_addr, host_port}; + + int sock_fd = socket(AF_INET, SOCK_DGRAM, 0); + if (sock_fd == -1) throw SocketError(errno); + sock_fd_ = SockFd{sock_fd}; + + sock_fd_.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1).value_or_throw(); + + // Enable kernel-space receive time measurement + sock_fd_.setsockopt(SOL_SOCKET, SO_TIMESTAMP, 1).value_or_throw(); + + // Enable reporting on packets dropped due to full UDP receive buffer + sock_fd_.setsockopt(SOL_SOCKET, SO_RXQ_OVFL, 1).value_or_throw(); + } + + /** + * @brief Set the socket to drop all packets not coming from `sender_ip` and `sender_port`. + * + * @param sender_ip The only allowed sender IP. Cannot be a multicast or broadcast address. + * @param sender_port The only allowed sender port. + */ + Builder && limit_to_sender(const std::string & sender_ip, uint16_t sender_port) + { + config_.sender.emplace(Endpoint{parse_ip_or_throw(sender_ip), sender_port}); + return std::move(*this); + } + + /** + * @brief Set the MTU this socket supports. While this can be set arbitrarily, it is best set to + * the MTU of the network interface, or to the maximum expected packet length. + * + * @param bytes The MTU size. The default value is 1500. + */ + Builder && set_mtu(size_t bytes) + { + config_.buffer_size = bytes; + return std::move(*this); + } + + /** + * @brief Set the internal socket receive buffer size. See `SO_RCVBUF` in `man 7 socket` for + * more information. + * + * @param bytes The desired buffer size in bytes. + */ + Builder && set_socket_buffer_size(size_t bytes) + { + if (bytes > static_cast(INT32_MAX)) + throw UsageError("The maximum value supported (0x7FFFFFF) has been exceeded"); + + auto buf_size = static_cast(bytes); + sock_fd_.setsockopt(SOL_SOCKET, SO_RCVBUF, buf_size).value_or_throw(); + return std::move(*this); + } + + /** + * @brief Join an IP multicast group. Only one group can be joined by the socket. + * + * @param group_ip The multicast IP. It has to be in the multicast range `224.0.0.0/4` (between + * `224.0.0.0` and `239.255.255.255`). + */ + Builder && join_multicast_group(const std::string & group_ip) + { + if (config_.multicast_ip) + throw UsageError("Only one multicast group can be joined by this socket"); + ip_mreq mreq{parse_ip_or_throw(group_ip), config_.host.ip}; + + sock_fd_.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq).value_or_throw(); + config_.multicast_ip.emplace(mreq.imr_multiaddr); + return std::move(*this); + } + + /** + * @brief Set the interval at which the socket polls for new data. THis should be longer than + * the expected interval of packets arriving in order to not poll unnecessarily often, and + * should be shorter than the acceptable time delay for `unsubscribe()`. The `unsubscribe()` + * function blocks up to one full poll interval before returning. + * + * @param interval_ms The desired polling interval. See `man poll` for the meanings of 0 and + * negative values. + */ + Builder && set_polling_interval(int32_t interval_ms) + { + config_.polling_interval_ms = interval_ms; + return std::move(*this); + } + + /** + * @brief Bind the socket to host IP and port given in `init()`. If `join_multicast_group()` was + * called before this function, the socket will be bound to `group_ip` instead. At least + * `init()` has to have been called before. + */ + UdpSocket bind() && + { + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(config_.host.port); + addr.sin_addr = config_.multicast_ip ? *config_.multicast_ip : config_.host.ip; + + int result = ::bind(sock_fd_.get(), (struct sockaddr *)&addr, sizeof(addr)); + if (result == -1) throw SocketError(errno); + + return UdpSocket{std::move(sock_fd_), config_}; + } + + private: + SockFd sock_fd_; + SocketConfig config_; + }; + + struct RxMetadata + { + std::optional timestamp_ns; + uint64_t drops_since_last_receive{0}; + bool truncated; + }; + + using callback_t = std::function &, const RxMetadata &)>; + + /** + * @brief Register a callback for processing received packets and start the receiver thread. The + * callback will be called for each received packet, and will be executed in the receive thread. + * Has to be called on a bound socket (`bind()` has to have been called before). + * + * @param callback The function to be executed for each received packet. + */ + UdpSocket & subscribe(callback_t && callback) + { + unsubscribe(); + callback_ = std::move(callback); + launch_receiver(); + return *this; + } + + bool is_subscribed() { return running_; } + + /** + * @brief Gracefully stops the active receiver thread (if any) but keeps the socket alive. The + * same socket can later be subscribed again. + */ + UdpSocket & unsubscribe() + { + running_ = false; + if (receive_thread_.joinable()) { + receive_thread_.join(); + } + return *this; + } + + UdpSocket(const UdpSocket &) = delete; + UdpSocket(UdpSocket && other) + : sock_fd_((other.unsubscribe(), std::move(other.sock_fd_))), + poll_fd_(std::move(other.poll_fd_)), + config_(std::move(other.config_)), + drop_monitor_(std::move(other.drop_monitor_)) + { + if (other.callback_) subscribe(std::move(other.callback_)); + }; + + UdpSocket & operator=(const UdpSocket &) = delete; + UdpSocket & operator=(UdpSocket &&) = delete; + + ~UdpSocket() { unsubscribe(); } + +private: + void launch_receiver() + { + assert(callback_); + + running_ = true; + receive_thread_ = std::thread([this]() { + std::vector buffer; + while (running_) { + auto data_available = is_data_available(); + if (!data_available.has_value()) throw SocketError(data_available.error()); + if (!data_available.value()) continue; + + buffer.resize(config_.buffer_size); + auto msg_header = make_msg_header(buffer); + + // As per `man recvmsg`, zero-length datagrams are permitted and valid. Since the socket is + // blocking, a recv_result of 0 means we received a valid 0-length datagram. + ssize_t recv_result = recvmsg(sock_fd_.get(), &msg_header.msg, MSG_TRUNC); + if (recv_result < 0) throw SocketError(errno); + size_t untruncated_packet_length = recv_result; + + if (!is_accepted_sender(msg_header.sender_addr)) continue; + + RxMetadata metadata; + get_receive_metadata(msg_header.msg, metadata); + metadata.truncated = untruncated_packet_length > config_.buffer_size; + + buffer.resize(std::min(config_.buffer_size, untruncated_packet_length)); + callback_(buffer, metadata); + } + }); + } + + void get_receive_metadata(msghdr & msg, RxMetadata & inout_metadata) + { + for (cmsghdr * cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET) continue; + + switch (cmsg->cmsg_type) { + case SO_TIMESTAMP: { + auto tv = (timeval const *)CMSG_DATA(cmsg); + uint64_t timestamp_ns = tv->tv_sec * 1'000'000'000 + tv->tv_usec * 1000; + inout_metadata.timestamp_ns.emplace(timestamp_ns); + break; + } + case SO_RXQ_OVFL: { + auto drops = (uint32_t const *)CMSG_DATA(cmsg); + inout_metadata.drops_since_last_receive = + drop_monitor_.get_drops_since_last_receive(*drops); + break; + } + default: + continue; + } + } + } + + util::expected is_data_available() + { + int status = poll(&poll_fd_, 1, config_.polling_interval_ms); + if (status == -1) return errno; + return (poll_fd_.revents & POLLIN) && (status > 0); + } + + bool is_accepted_sender(const sockaddr_in & sender_addr) + { + if (!config_.sender) return true; + return sender_addr.sin_addr.s_addr == config_.sender->ip.s_addr; + } + + static MsgBuffers make_msg_header(std::vector & receive_buffer) + { + msghdr msg{}; + iovec iov{}; + std::array control; + + sockaddr_in sender_addr; + socklen_t sender_addr_len = sizeof(sender_addr); + + iov.iov_base = receive_buffer.data(); + iov.iov_len = receive_buffer.size(); + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = control.data(); + msg.msg_controllen = control.size(); + msg.msg_name = &sender_addr; + msg.msg_namelen = sender_addr_len; + + return MsgBuffers{msg, iov, control, sender_addr}; + } + + static in_addr parse_ip_or_throw(const std::string & ip) + { + in_addr parsed_addr; + bool valid = inet_aton(ip.c_str(), &parsed_addr); + if (!valid) throw UsageError("Invalid IP address given"); + return parsed_addr; + } + + SockFd sock_fd_; + pollfd poll_fd_; + + SocketConfig config_; + + std::atomic_bool running_{false}; + std::thread receive_thread_; + callback_t callback_; + + DropMonitor drop_monitor_; +}; + +} // namespace nebula::drivers::connections diff --git a/nebula_hw_interfaces/test/common/test_udp.cpp b/nebula_hw_interfaces/test/common/test_udp.cpp new file mode 100644 index 000000000..1c800878e --- /dev/null +++ b/nebula_hw_interfaces/test/common/test_udp.cpp @@ -0,0 +1,210 @@ +// Copyright 2024 TIER IV, Inc. + +#include "common/test_udp/utils.hpp" +#include "nebula_common/util/expected.hpp" +#include "nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nebula::drivers::connections +{ + +using std::chrono_literals::operator""ms; + +static const char localhost_ip[] = "127.0.0.1"; +static const char broadcast_ip[] = "255.255.255.255"; +static const char any_ip[] = "0.0.0.0"; +static const char multicast_group[] = "230.1.2.3"; +static const char multicast_group2[] = "230.4.5.6"; + +static const char sender_ip[] = "192.168.201"; +static const uint16_t sender_port = 7373; +static const uint16_t host_port = 6262; + +static const std::chrono::duration send_receive_timeout = 100ms; + +UdpSocket::callback_t empty_cb() +{ + return [](const auto &, const auto &) {}; +} + +util::expected read_sys_param(const std::string & param_fqn) +{ + std::string path = "/proc/sys/" + param_fqn; + std::replace(path.begin(), path.end(), '.', '/'); + std::ifstream ifs{path}; + if (!ifs) return "could not read " + param_fqn; + + size_t param{}; + if (!(ifs >> param)) return param_fqn + " has unrecognized format"; + return param; +} + +TEST(test_udp, test_basic_lifecycle) +{ + ASSERT_NO_THROW( + UdpSocket::Builder(localhost_ip, host_port).bind().subscribe(empty_cb()).unsubscribe()); +} + +TEST(test_udp, test_special_addresses_bind) +{ + ASSERT_THROW(UdpSocket::Builder(broadcast_ip, host_port), UsageError); + ASSERT_NO_THROW(UdpSocket::Builder(any_ip, host_port).bind()); +} + +TEST(test_udp, test_joining_invalid_multicast_group) +{ + ASSERT_THROW( + UdpSocket::Builder(localhost_ip, host_port).join_multicast_group(broadcast_ip).bind(), + SocketError); +} + +TEST(test_udp, test_buffer_resize) +{ + auto rmem_max_maybe = read_sys_param("net.core.rmem_max"); + if (!rmem_max_maybe.has_value()) GTEST_SKIP() << rmem_max_maybe.error(); + size_t rmem_max = rmem_max_maybe.value(); + + // Setting buffer sizes up to and including rmem_max shall succeed + ASSERT_NO_THROW( + UdpSocket::Builder(localhost_ip, host_port).set_socket_buffer_size(rmem_max).bind()); + + // Linux only supports sizes up to INT32_MAX + ASSERT_THROW( + UdpSocket::Builder(localhost_ip, host_port) + .set_socket_buffer_size(static_cast(INT32_MAX) + 1) + .bind(), + UsageError); +} + +TEST(test_udp, test_correct_usage_is_enforced) +{ + // The following functions can be called in any order, any number of times + ASSERT_NO_THROW(UdpSocket::Builder(localhost_ip, host_port) + .set_polling_interval(20) + .set_socket_buffer_size(3000) + .set_mtu(1600) + .limit_to_sender(sender_ip, sender_port) + .set_polling_interval(20) + .set_socket_buffer_size(3000) + .set_mtu(1600) + .limit_to_sender(sender_ip, sender_port) + .bind()); + + // Only one multicast group can be joined + ASSERT_THROW( + UdpSocket::Builder(localhost_ip, host_port) + .join_multicast_group(multicast_group) + .join_multicast_group(multicast_group2), + UsageError); + + // Pre-existing subscriptions shall be gracefully unsubscribed when a new subscription is created + ASSERT_NO_THROW( + UdpSocket::Builder(localhost_ip, host_port).bind().subscribe(empty_cb()).subscribe(empty_cb())); + + // Explicitly unsubscribing shall be supported + ASSERT_NO_THROW(UdpSocket::Builder(localhost_ip, host_port) + .bind() + .subscribe(empty_cb()) + .unsubscribe() + .subscribe(empty_cb())); + + // Unsubscribing on a non-subscribed socket shall also be supported + ASSERT_NO_THROW(UdpSocket::Builder(localhost_ip, host_port).bind().unsubscribe()); +} + +TEST(test_udp, test_receiving) +{ + const std::vector payload{1, 2, 3}; + auto sock = UdpSocket::Builder(localhost_ip, host_port).bind(); + + auto err_no_opt = udp_send(localhost_ip, host_port, payload); + if (err_no_opt.has_value()) GTEST_SKIP() << strerror(err_no_opt.value()); + + auto result_opt = receive_once(sock, send_receive_timeout); + + ASSERT_TRUE(result_opt.has_value()); + auto const & [recv_payload, metadata] = result_opt.value(); + ASSERT_EQ(recv_payload, payload); + ASSERT_FALSE(metadata.truncated); + ASSERT_EQ(metadata.drops_since_last_receive, 0); + + // TODO(mojomex): currently cannot test timestamping on loopback interface (no timestamp produced) +} + +TEST(test_udp, test_receiving_oversized) +{ + const size_t mtu = 1500; + std::vector payload; + payload.resize(mtu + 1, 0x42); + auto sock = UdpSocket::Builder(localhost_ip, host_port).set_mtu(mtu).bind(); + + auto err_no_opt = udp_send(localhost_ip, host_port, payload); + if (err_no_opt.has_value()) GTEST_SKIP() << strerror(err_no_opt.value()); + + auto result_opt = receive_once(sock, send_receive_timeout); + + ASSERT_TRUE(result_opt.has_value()); + auto const & [recv_payload, metadata] = result_opt.value(); + ASSERT_EQ(recv_payload.size(), mtu); + ASSERT_TRUE(std::equal(recv_payload.begin(), recv_payload.end(), payload.begin())); + ASSERT_TRUE(metadata.truncated); + ASSERT_EQ(metadata.drops_since_last_receive, 0); +} + +TEST(test_udp, test_filtering_sender) +{ + std::vector payload{1, 2, 3}; + auto sock = + UdpSocket::Builder(localhost_ip, host_port).limit_to_sender(sender_ip, sender_port).bind(); + + auto err_no_opt = udp_send(localhost_ip, host_port, payload); + if (err_no_opt.has_value()) GTEST_SKIP() << strerror(err_no_opt.value()); + + auto result_opt = receive_once(sock, send_receive_timeout); + ASSERT_FALSE(result_opt.has_value()); +} + +TEST(test_udp, test_moveable) +{ + std::vector payload{1, 2, 3}; + + size_t n_received = 0; + + auto sock = UdpSocket::Builder(localhost_ip, host_port).bind(); + sock.subscribe([&n_received](const auto &, const auto &) { n_received++; }); + + auto err_no_opt = udp_send(localhost_ip, host_port, payload); + if (err_no_opt.has_value()) GTEST_SKIP() << strerror(err_no_opt.value()); + + // The subscription moves to the new socket object + UdpSocket sock2{std::move(sock)}; + ASSERT_TRUE(sock2.is_subscribed()); + + err_no_opt = udp_send(localhost_ip, host_port, payload); + if (err_no_opt.has_value()) GTEST_SKIP() << strerror(err_no_opt.value()); + + std::this_thread::sleep_for(100ms); + ASSERT_EQ(n_received, 2); +} + +} // namespace nebula::drivers::connections + +int main(int argc, char * argv[]) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}; diff --git a/nebula_hw_interfaces/test/common/test_udp/utils.hpp b/nebula_hw_interfaces/test/common/test_udp/utils.hpp new file mode 100644 index 000000000..a4d8878ce --- /dev/null +++ b/nebula_hw_interfaces/test/common/test_udp/utils.hpp @@ -0,0 +1,77 @@ +// Copyright 2024 TIER IV, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "nebula_hw_interfaces/nebula_hw_interfaces_common/connections/udp.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +std::optional udp_send( + const char * to_ip, uint16_t to_port, const std::vector & bytes) +{ + int sock_fd = socket(AF_INET, SOCK_DGRAM, 0); + if (sock_fd < 0) return errno; + + int enable = 1; + int result = setsockopt(sock_fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); + if (result < 0) return errno; + + sockaddr_in receiver_addr{}; + memset(&receiver_addr, 0, sizeof(receiver_addr)); + receiver_addr.sin_family = AF_INET; + receiver_addr.sin_port = htons(to_port); + receiver_addr.sin_addr.s_addr = inet_addr(to_ip); + + result = sendto( + sock_fd, bytes.data(), bytes.size(), 0, reinterpret_cast(&receiver_addr), + sizeof(receiver_addr)); + if (result < 0) return errno; + result = close(sock_fd); + if (result < 0) return errno; + return {}; +} + +template +std::optional, nebula::drivers::connections::UdpSocket::RxMetadata>> +receive_once(nebula::drivers::connections::UdpSocket & sock, std::chrono::duration<_T, _R> timeout) +{ + std::condition_variable cv_received_result; + std::mutex mtx_result; + std::optional< + std::pair, nebula::drivers::connections::UdpSocket::RxMetadata>> + result; + + sock.subscribe([&](const auto & data, const auto & metadata) { + std::lock_guard lock(mtx_result); + result.emplace(data, metadata); + cv_received_result.notify_one(); + }); + + std::unique_lock lock(mtx_result); + cv_received_result.wait_for(lock, timeout, [&result]() { return result.has_value(); }); + sock.unsubscribe(); + return result; +}