Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in MessageFilter #538

Closed
wants to merge 9 commits into from
19 changes: 2 additions & 17 deletions tf2/src/buffer_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
#include <assert.h>
#include <console_bridge/console.h>
#include "tf2/LinearMath/Transform.h"
#include <boost/foreach.hpp>

namespace tf2
{
Expand Down Expand Up @@ -1406,11 +1405,6 @@ void BufferCore::testTransformableRequests()
{
boost::mutex::scoped_lock lock(transformable_requests_mutex_);
V_TransformableRequest::iterator it = transformable_requests_.begin();

typedef boost::tuple<TransformableCallback&, TransformableRequestHandle, std::string,
std::string, ros::Time&, TransformableResult&> TransformableTuple;
std::vector<TransformableTuple> transformables;

for (; it != transformable_requests_.end();)
{
TransformableRequest& req = *it;
Expand Down Expand Up @@ -1450,12 +1444,8 @@ void BufferCore::testTransformableRequests()
M_TransformableCallback::iterator it = transformable_callbacks_.find(req.cb_handle);
if (it != transformable_callbacks_.end())
{
transformables.push_back(boost::make_tuple(boost::ref(it->second),
req.request_handle,
lookupFrameString(req.target_id),
lookupFrameString(req.source_id),
boost::ref(req.time),
boost::ref(result)));
const TransformableCallback& cb = it->second;
cb(req.request_handle, lookupFrameString(req.target_id), lookupFrameString(req.source_id), req.time, result);
}
}

Expand All @@ -1475,11 +1465,6 @@ void BufferCore::testTransformableRequests()
// unlock before allowing possible user callbacks to avoid potential deadlock (#91)
lock.unlock();

BOOST_FOREACH (TransformableTuple tt, transformables)
{
tt.get<0>()(tt.get<1>(), tt.get<2>(), tt.get<3>(), tt.get<4>(), tt.get<5>());
}

// Backwards compatability callback for tf
_transforms_changed_();
}
Expand Down
55 changes: 29 additions & 26 deletions tf2_ros/include/tf2_ros/message_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
~MessageFilter()
{
message_connection_.disconnect();

MessageFilter::clear();
bc_.removeTransformableCallback(callback_handle_);

TF2_ROS_MESSAGEFILTER_DEBUG("Successful Transforms: %llu, Discarded due to age: %llu, Transform messages received: %llu, Messages received: %llu, Total dropped: %llu",
(long long unsigned int)successful_transform_count_,
(long long unsigned int)failed_out_the_back_count_, (long long unsigned int)transform_message_count_,
(long long unsigned int)incoming_message_count_, (long long unsigned int)dropped_message_count_);

boost::unique_lock<boost::shared_mutex> lock(cbqueue_mutex_); // ensure that no more callback queue calls are active
}

/**
Expand Down Expand Up @@ -277,13 +278,11 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
*/
void clear()
{
boost::unique_lock< boost::shared_mutex > unique_lock(messages_mutex_);

TF2_ROS_MESSAGEFILTER_DEBUG("%s", "Cleared");

bc_.removeTransformableCallback(callback_handle_);
callback_handle_ = bc_.addTransformableCallback(boost::bind(&MessageFilter::transformable, this, _1, _2, _3, _4, _5));

// acquire after remove/addTransformableCallback to avoid deadlock!
boost::unique_lock<boost::shared_mutex> unique_lock(messages_mutex_);
messages_.clear();
message_count_ = 0;

Expand All @@ -292,6 +291,7 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
callback_queue_->removeByID((uint64_t)this);

warned_about_empty_frame_id_ = false;
TF2_ROS_MESSAGEFILTER_DEBUG("%s", "Cleared");
}

void add(const MEvent& evt)
Expand Down Expand Up @@ -363,6 +363,7 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
}
}

L_MessageInfo msgs_to_drop;

// We can transform already
if (info.success_count == expected_success_count_)
Expand All @@ -371,26 +372,13 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
}
else
{
boost::unique_lock< boost::shared_mutex > unique_lock(messages_mutex_);
boost::unique_lock<boost::shared_mutex> unique_lock(messages_mutex_);
// If this message is about to push us past our queue size, erase the oldest message
if (queue_size_ != 0 && message_count_ + 1 > queue_size_)
{
++dropped_message_count_;
const MessageInfo& front = messages_.front();
TF2_ROS_MESSAGEFILTER_DEBUG("Removed oldest message because buffer is full, count now %d (frame_id=%s, stamp=%f)", message_count_,
(mt::FrameId<M>::value(*front.event.getMessage())).c_str(), mt::TimeStamp<M>::value(*front.event.getMessage()).toSec());

V_TransformableRequestHandle::const_iterator it = front.handles.begin();
V_TransformableRequestHandle::const_iterator end = front.handles.end();

for (; it != end; ++it)
{
bc_.cancelTransformableRequest(*it);
}

messageDropped(front.event, filter_failure_reasons::Unknown);
messages_.pop_front();
--message_count_;
// move front element from messages_ to msgs_to_drop for later dropping
msgs_to_drop.splice(msgs_to_drop.begin(), messages_, messages_.begin());
--message_count_;
}

// Add the message to our list
Expand All @@ -399,6 +387,19 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
++message_count_;
}

// Delay dropping of messages until we released messages_mutex_ to avoid deadlocks (#91, #101, #144)
for (const MessageInfo &msg : msgs_to_drop)
{
++dropped_message_count_;
TF2_ROS_MESSAGEFILTER_DEBUG("Removed oldest message because buffer is full, count now %d (frame_id=%s, stamp=%f)", message_count_,
(mt::FrameId<M>::value(*msg.event.getMessage())).c_str(), mt::TimeStamp<M>::value(*msg.event.getMessage()).toSec());

for (const auto req : msg.handles)
bc_.cancelTransformableRequest(req);

messageDropped(msg.event, filter_failure_reasons::Unknown);
}

TF2_ROS_MESSAGEFILTER_DEBUG("Added message in frame %s at time %.3f, count now %d", frame_id.c_str(), stamp.toSec(), message_count_);

++incoming_message_count_;
Expand Down Expand Up @@ -461,7 +462,7 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
{
namespace mt = ros::message_traits;

boost::upgrade_lock< boost::shared_mutex > lock(messages_mutex_);
boost::upgrade_lock<boost::shared_mutex> read_lock(messages_mutex_);

// find the message this request is associated with
typename L_MessageInfo::iterator msg_it = messages_.begin();
Expand Down Expand Up @@ -524,8 +525,6 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
can_transform = false;
}

// We will be mutating messages now, require unique lock
boost::upgrade_to_unique_lock< boost::shared_mutex > uniqueLock(lock);
if (can_transform)
{
TF2_ROS_MESSAGEFILTER_DEBUG("Message ready in frame %s at time %.3f, count now %d", frame_id.c_str(), stamp.toSec(), message_count_ - 1);
Expand All @@ -543,6 +542,8 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
messageDropped(info.event, filter_failure_reasons::Unknown);
}

// We will be mutating messages now, require unique lock
boost::upgrade_to_unique_lock<boost::shared_mutex> write_lock(read_lock);
messages_.erase(msg_it);
--message_count_;
}
Expand Down Expand Up @@ -595,6 +596,7 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi

virtual CallResult call()
{
boost::shared_lock<boost::shared_mutex> lock(filter_->cbqueue_mutex_);
if (success_)
{
filter_->signalMessage(event_);
Expand Down Expand Up @@ -668,7 +670,8 @@ class MessageFilter : public MessageFilterBase, public message_filters::SimpleFi
V_string target_frames_; ///< The frames we need to be able to transform to before a message is ready
std::string target_frames_string_;
boost::mutex target_frames_mutex_; ///< A mutex to protect access to the target_frames_ list and target_frames_string.
uint32_t queue_size_; ///< The maximum number of messages we queue up
boost::shared_mutex cbqueue_mutex_; ///< A mutex protecting calls from callback queues
uint32_t queue_size_; ///< The maximum number of messages we queue up
tf2::TransformableCallbackHandle callback_handle_;

typedef std::vector<tf2::TransformableRequestHandle> V_TransformableRequestHandle;
Expand Down
120 changes: 120 additions & 0 deletions tf2_ros/test/message_filter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,126 @@ TEST(tf2_ros_message_filter, multiple_frames_and_time_tolerance)
ASSERT_TRUE(filter_callback_fired);
}

template <class M>
class MessageGenerator : public message_filters::SimpleFilter<M>
{
public:
template <typename F>
void connectInput(F &)
{
}

void add(const ros::MessageEvent<M const> &)
{
}

void generate(const std::string &frame_id, const ros::Time &time)
{
auto msg = boost::make_shared<M>();
msg->header.frame_id = frame_id;
msg->header.stamp = time;
this->signalMessage(msg);
}
};

class MessageFilterFixture : public ::testing::TestWithParam<bool>
{
using M = geometry_msgs::PointStamped;

protected:
tf2_ros::Buffer buffer;
MessageGenerator<M> source;
std::list<tf2_ros::MessageFilter<M>> filters;
bool run = true;

struct Sink
{
std::string name_;
int delay_;

Sink(const std::string &name, int delay = 0) : name_(name), delay_(delay) {}
void operator()(const boost::shared_ptr<const M> &msg)
{
std::this_thread::sleep_for(std::chrono::milliseconds(delay_));
}
};

public:
void msg_gen()
{
ros::WallRate rate(100); // publish messages @ 100Hz
const std::string frame_id("target");
while (ros::ok() && run)
{
source.generate(frame_id, ros::Time::now());
rate.sleep();
}
};

void frame_gen()
{
ros::WallRate rate(50); // publish frame info @ 50 Hz (slower than msgs)
while (ros::ok() && run)
{
geometry_msgs::TransformStamped transform;
transform.header.stamp = ros::Time::now();
transform.header.frame_id = "base";
transform.child_frame_id = "target";
transform.transform.translation.x = 0.0;
transform.transform.translation.y = 0.0;
transform.transform.translation.z = 0.0;
transform.transform.rotation.x = 0.0;
transform.transform.rotation.y = 0.0;
transform.transform.rotation.z = 0.0;
transform.transform.rotation.w = 1.0;
buffer.setTransform(transform, "frame_generator", false);
rate.sleep();
}
};

void add_filter(int i, ros::CallbackQueueInterface *queue)
{
std::string name(queue ? "Q" : "S");
name += std::to_string(i);

filters.emplace_back(buffer, "base", i + 1, queue);
auto &f = filters.back();
f.setName(name);
f.connectInput(source);
f.registerCallback(Sink(name, 1));
};
};

TEST_P(MessageFilterFixture, StressTest)
{
ros::NodeHandle nh;
ros::AsyncSpinner spinner(1);
spinner.start();

std::thread msg_gen(&MessageFilterFixture::msg_gen, this);
std::thread frame_gen(&MessageFilterFixture::frame_gen, this);

bool use_cbqueue = GetParam();
ros::CallbackQueueInterface *queue = use_cbqueue ? nh.getCallbackQueue() : nullptr;
// use fewer filters for signal-only transmission as we can remove only a single filter per iteration
int num_filters = use_cbqueue ? 50 : 10;
for (int i = 0; i < num_filters; ++i)
add_filter(i, queue);

// slowly remove filters
std::this_thread::sleep_for(std::chrono::milliseconds(20));
while (!filters.empty())
{
std::this_thread::sleep_for(std::chrono::milliseconds(7));
filters.pop_front();
}

run = false;
msg_gen.join();
frame_gen.join();
}
INSTANTIATE_TEST_CASE_P(MessageFilterTests, MessageFilterFixture, ::testing::Values(false, true));

int main(int argc, char **argv){
testing::InitGoogleTest(&argc, argv);
ros::init(argc, argv, "tf2_ros_message_filter");
Expand Down