Skip to content

Commit

Permalink
Make more of Node thread-safe (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
domfarolino authored Aug 29, 2023
1 parent a4db9ef commit d1f1a90
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 17 deletions.
49 changes: 49 additions & 0 deletions mage/core/core_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ class CoreUnitTest : public testing::TestWithParam<MainThreadType> {
io_thread.StopWhenIdle(); // Blocks.
main_thread.reset();
dummy_launcher.reset();
for (int& socket : extra_sockets_) {
EXPECT_EQ(close(socket), 0);
}
}

// Creates `num_sockets` connected AF_UNIX stream socket pairs. One end of
// each pair is handed back to the caller via `return_sockets`; BOTH ends are
// recorded in `extra_sockets_` so the fixture teardown can close them.
void GetSockets(int num_sockets, std::vector<int>& return_sockets) {
  // This method is only called once per test.
  ASSERT_EQ(extra_sockets_.size(), 0);

  for (int i = 0; i < num_sockets; ++i) {
    int socket_pair[2];
    ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, socket_pair), 0);
    // Pass the flags argument explicitly: `fcntl(fd, F_SETFL)` with no third
    // argument reads an indeterminate vararg, which is undefined behavior per
    // POSIX. A flags value of 0 (blocking mode, the default for socketpair())
    // matches the apparent intent of asserting a 0 return here.
    // NOTE(review): confirm O_NONBLOCK was not intended instead.
    ASSERT_EQ(fcntl(socket_pair[0], F_SETFL, 0), 0);
    ASSERT_EQ(fcntl(socket_pair[1], F_SETFL, 0), 0);

    return_sockets.push_back(socket_pair[0]);
    extra_sockets_.push_back(socket_pair[0]);
    extra_sockets_.push_back(socket_pair[1]);
  }
}

protected:
Expand All @@ -119,6 +138,7 @@ class CoreUnitTest : public testing::TestWithParam<MainThreadType> {
std::unique_ptr<DummyProcessLauncher> dummy_launcher;
std::shared_ptr<base::TaskLoop> main_thread;
base::Thread io_thread;
std::vector<int> extra_sockets_;
};

INSTANTIATE_TEST_SUITE_P(All,
Expand Down Expand Up @@ -312,4 +332,33 @@ TEST_P(CoreUnitTest, CreateMessagePipesFromAnyThread) {
EXPECT_EQ(NodeLocalEndpoints().size(), 20000);
}

// This test spins up `kNumThreads` worker threads that each send
// `kNumInvitationsPerThread` invitations, each invitation over its own
// dedicated socket, to exercise the thread-safety of the invitation-sending
// path from arbitrary (non-IO) threads.
TEST_P(CoreUnitTest, MultiThreadRacyInvitationSending) {
  const int kNumThreads = 10;
  const int kNumInvitationsPerThread = 8;
  // Get one socket per invitation (kNumThreads * kNumInvitationsPerThread in
  // total).
  std::vector<int> sockets;
  GetSockets(kNumThreads * kNumInvitationsPerThread, sockets);
  ASSERT_EQ(sockets.size(), kNumThreads * kNumInvitationsPerThread);

  std::vector<std::unique_ptr<base::Thread>> worker_threads;
  for (int i = 0; i < kNumThreads; ++i) {
    worker_threads.push_back(std::make_unique<base::Thread>(base::ThreadType::WORKER));
    worker_threads[i]->Start();
    // Capturing `sockets` by reference is safe here: the `StopWhenIdle()`
    // loop below blocks until every posted task has run, so `sockets`
    // outlives all of the lambdas.
    worker_threads[i]->GetTaskRunner()->PostTask([i, &sockets](){
      for (int j = 0; j < kNumInvitationsPerThread; ++j) {
        mage::Core::SendInvitationAndGetMessagePipe(sockets[i * kNumInvitationsPerThread + j]);
      }
    });
  }

  // Block until each worker thread has drained its queue, i.e. finished
  // sending all of its invitations, before asserting on global state.
  for (auto& worker_thread : worker_threads) {
    worker_thread->StopWhenIdle();
  }

  // Each sent invitation accounts for two entries in each table — presumably
  // the local endpoint plus the endpoint reserved for the remote peer (hence
  // the factor of 2) — TODO confirm against `SendInvitationAndGetMessagePipe`.
  EXPECT_EQ(CoreHandleTable().size(), kNumThreads * kNumInvitationsPerThread * 2);
  EXPECT_EQ(NodeLocalEndpoints().size(), kNumThreads * kNumInvitationsPerThread * 2);
}

}; // namespace mage
72 changes: 56 additions & 16 deletions mage/core/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ MessagePipe Node::SendInvitationAndGetMessagePipe(int fd) {

NodeName temporary_remote_node_name = util::RandomIdentifier();

// TODO(domfarolino): Probably lock `node_channel_map_` here, since another
// thread could be accessing this at the same time when sending a remote
// message.
// Lock this map because it could be accessed from many threads, since the
// embedder can choose to send invitations from any thread.
node_channel_map_lock_.lock();
auto it = node_channel_map_.insert(
{temporary_remote_node_name,
std::make_unique<Channel>(fd, /*delegate=*/this)});
// TODO(domfarolino): Maybe lock this?
node_channel_map_lock_.unlock();
pending_invitations_lock_.lock();
pending_invitations_.insert({temporary_remote_node_name, remote_endpoint});
pending_invitations_lock_.unlock();

// Similar to `AcceptInvitation()` below, only start the channel after it and
// `remote_endpoint` have been inserted into their maps. Right when we the
Expand All @@ -148,11 +150,30 @@ void Node::AcceptInvitation(int fd) {

LOG("Node::AcceptInvitation() getpid: %d", getpid());
std::unique_ptr<Channel> channel(new Channel(fd, this));
// Thread-safety: We don't need to take a lock on this map here because in
// general, nodes that accept invitations should not be sending them too.
// Doing so is technically possible for now, but is considered "unsupported",
// and therefore we don't go out of our way to make it safe. See the
// corresponding documentation in `Node::OnReceivedInvitation()`, which gets
// invoked asynchronously after we call `Channel::Start()` below.
auto it = node_channel_map_.insert({kInitialChannelName, std::move(channel)});

// Start the channel *after* it is inserted into the map, because right when
// it starts, the IO thread could try and read from `node_channel_map_` at any
// time, since messages will start coming in (for example,
// `OnReceivedInvitation()`).
// it starts, messages could start coming in and being read from any other
// thread. This is because `Channel::Start()` defers to the embedder-supplied
// IO mechanism for receiving messages on `fd`, and that mechanism might
// always listen for (and read) messages on whatever thread that IO mechanism
// is bound to, which may be different than *this* thread.
//
// For example, if the embedder is using the `//base` library [1], there's
// only ever a single "IO" thread capable of socket communication at any given
// time in a process. If this method is run on any other thread than the IO
// thread, then right when `Channel::Start()` is invoked on *this* thread,
// incoming messages like `OnReceivedInvitation()` could start coming in on
// the IO thread at any time, and expect `channel` to exist in
// `node_channel_map_` under the `kInitialChannelName` name.
//
// [1]: https://github.com/domfarolino/base.
it.first->second->Start();

has_accepted_invitation_ = true;
Expand Down Expand Up @@ -260,8 +281,6 @@ void Node::SendMessage(std::shared_ptr<Endpoint> local_endpoint,
}

void Node::OnReceivedMessage(Message message) {
CHECK_ON_THREAD(base::ThreadType::IO);

switch (message.Type()) {
case MessageType::SEND_INVITATION:
OnReceivedInvitation(std::move(message));
Expand All @@ -278,7 +297,6 @@ void Node::OnReceivedMessage(Message message) {
}

void Node::OnReceivedInvitation(Message message) {
CHECK_ON_THREAD(base::ThreadType::IO);
SendInvitationParams* params = message.GetView<SendInvitationParams>();

// Deserialize
Expand All @@ -298,6 +316,13 @@ void Node::OnReceivedInvitation(Message message) {

// Now that we know our inviter's name, we can find our initial channel in our
// map, and change the entry's key to the actual inviter's name.
//
// Thread-safety: We don't need to take a lock over this map. That's because a
// node can only accept a single invitation throughout its life, and
// invitation-accepting nodes shouldn't be able to *send* invitations (see the
// documentation in `Node::AcceptInvitation()`). This is safe because at this
// point, `this` can't know about any other nodes/channels yet, so this map is
// essentially static at this point.
auto it = node_channel_map_.find(kInitialChannelName);
CHECK_NE(it, node_channel_map_.end());
std::unique_ptr<Channel> init_channel = std::move(it->second);
Expand Down Expand Up @@ -330,7 +355,6 @@ void Node::OnReceivedInvitation(Message message) {
}

void Node::OnReceivedAcceptInvitation(Message message) {
CHECK_ON_THREAD(base::ThreadType::IO);
SendAcceptInvitationParams* params =
message.GetView<SendAcceptInvitationParams>();
std::string temporary_remote_node_name(
Expand All @@ -347,13 +371,18 @@ void Node::OnReceivedAcceptInvitation(Message message) {

// We should only get ACCEPT_INVITATION messages from nodes that we have a
// pending invitation for.
//
// In order to acknowledge the invitation acceptance, we must do four things:
// 1.) Remove the pending invitation from |pending_invitations_|.
pending_invitations_lock_.lock();
auto remote_endpoint_it =
pending_invitations_.find(temporary_remote_node_name);
CHECK_NE(remote_endpoint_it, pending_invitations_.end());
std::shared_ptr<Endpoint> remote_endpoint = remote_endpoint_it->second;
pending_invitations_.erase(temporary_remote_node_name);
pending_invitations_lock_.unlock();

// In order to acknowledge the invitation acceptance, we must do four things:
// 1.) Put |remote_endpoint| in the `kUnboundAndProxying` state, so that
// 2.) Put |remote_endpoint| in the `kUnboundAndProxying` state, so that
// when `SendMessage()` gets message bound for it, it knows to forward
// them to the appropriate remote node.
CHECK_NE(local_endpoints_.find(remote_endpoint->name),
Expand All @@ -372,16 +401,19 @@ void Node::OnReceivedAcceptInvitation(Message message) {
remote_endpoint->proxy_target.node_name.c_str(),
remote_endpoint->proxy_target.endpoint_name.c_str());

// 2.) Remove the pending invitation from |pending_invitations_|.
pending_invitations_.erase(temporary_remote_node_name);

// 3.) Update |node_channel_map_| to correctly be keyed off of
// |actual_node_name|.
//
// Lock the node channel map because in the path of acknowledging an
// accepted invitation, other threads could be sending other invitations
// and mutating the same map.
node_channel_map_lock_.lock();
auto node_channel_it = node_channel_map_.find(temporary_remote_node_name);
CHECK_NE(node_channel_it, node_channel_map_.end());
std::unique_ptr<Channel> channel = std::move(node_channel_it->second);
node_channel_map_.erase(temporary_remote_node_name);
node_channel_map_.insert({actual_node_name, std::move(channel)});
node_channel_map_lock_.unlock();

// 4.) Forward any messages that were queued in |remote_endpoint| so that
// the remote node's endpoint gets them. Note that the messages queued
Expand Down Expand Up @@ -506,8 +538,10 @@ void Node::SendMessagesAndRecursiveDependents(
}

// Forward the message and remove it from the queue.
node_channel_map_lock_.lock();
node_channel_map_[target_node_name]->SendMessage(
std::move(message_to_send));
node_channel_map_lock_.unlock();

// See the documentation above `locked_dependent_endpoints`.
for (const std::shared_ptr<Endpoint>& endpoint : locked_dependent_endpoints)
Expand Down Expand Up @@ -562,8 +596,14 @@ void Node::OnReceivedUserMessage(Message message) {
memcpy(message.GetMutableMessageHeader().target_endpoint,
proxy_target.endpoint_name.c_str(), kIdentifierSize);
PrepareToForwardUserMessage(endpoint, message);
// Lock the node channel map because the thread we receive a message a
// message on (and therefore forward messages to another node in this
// proxying case) could be different than the arbitrary threads that might
// be sending invitations to other processes.
node_channel_map_lock_.lock();
node_channel_map_[endpoint->proxy_target.node_name]->SendMessage(
std::move(message));
node_channel_map_lock_.unlock();
break;
}
case Endpoint::State::kBound:
Expand Down
3 changes: 2 additions & 1 deletion mage/core/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class Node : public Channel::Delegate {

// Thread-safe.
std::vector<MessagePipe> CreateMessagePipes();

MessagePipe SendInvitationAndGetMessagePipe(int fd);
void AcceptInvitation(int fd);
void SendMessage(std::shared_ptr<Endpoint> local_endpoint, Message message);
Expand Down Expand Up @@ -82,6 +81,7 @@ class Node : public Channel::Delegate {
// endpoint in the same node (that is, from an endpoint in Node A to its peer
// endpoint also in Node A) go through a different path.
std::map<NodeName, std::unique_ptr<Channel>> node_channel_map_;
base::Mutex node_channel_map_lock_;

// Maps |NodeNames| that we've sent invitations to and are awaiting
// acceptances from, to an |Endpoint| that we've reserved for the peer node.
Expand All @@ -91,6 +91,7 @@ class Node : public Channel::Delegate {
// we've given it, we update instances of its temporary name with its "real"
// one that it provides in the invitation acceptance message.
std::map<NodeName, std::shared_ptr<Endpoint>> pending_invitations_;
base::Mutex pending_invitations_lock_;
};

}; // namespace mage
Expand Down
1 change: 1 addition & 0 deletions mage/public/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ std::vector<MessagePipe> CreateMessagePipes();
MessagePipe SendInvitationAndGetMessagePipe(
int fd,
base::OnceClosure callback = base::OnceClosure());
// Should only be called once per process (generally right after `Init()`).
void AcceptInvitation(
int fd,
std::function<void(MessagePipe)> finished_accepting_invitation_callback);
Expand Down

0 comments on commit d1f1a90

Please sign in to comment.