diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h index 95f6cba17..87a47ca3b 100644 --- a/include/faabric/transport/PointToPointBroker.h +++ b/include/faabric/transport/PointToPointBroker.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -120,27 +121,16 @@ class PointToPointBroker void updateHostForIdx(int groupId, int groupIdx, std::string newHost); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, std::string hostHint, bool mustOrderMsg = false); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, bool mustOrderMsg = false, int sequenceNum = NO_SEQUENCE_NUM, std::string hostHint = ""); - std::vector recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg = false); + void recvMessage(PointToPointMessage& msg, bool mustOrderMsg = false); void clearGroup(int groupId); @@ -163,7 +153,8 @@ class PointToPointBroker std::shared_ptr getGroupFlag(int groupId); - Message doRecvMessage(int groupId, int sendIdx, int recvIdx); + // Returns the message response code and the sequence number + std::pair doRecvMessage(PointToPointMessage& msg); void initSequenceCounters(int groupId); diff --git a/include/faabric/transport/PointToPointClient.h b/include/faabric/transport/PointToPointClient.h index 634b41579..5e5add933 100644 --- a/include/faabric/transport/PointToPointClient.h +++ b/include/faabric/transport/PointToPointClient.h @@ -3,18 +3,19 @@ #include #include #include +#include namespace faabric::transport { std::vector> getSentMappings(); -std::vector> +std::vector> getSentPointToPointMessages(); std::vector> + PointToPointMessage>> getSentLockMessages(); void clearSentMessages(); @@ -26,7 +27,7 @@ class PointToPointClient : public faabric::transport::MessageEndpointClient void sendMappings(faabric::PointToPointMappings& mappings); - void sendMessage(faabric::PointToPointMessage& msg, + void sendMessage(const PointToPointMessage& msg, int sequenceNum = NO_SEQUENCE_NUM); void groupLock(int appId, diff --git a/include/faabric/transport/PointToPointMessage.h b/include/faabric/transport/PointToPointMessage.h new file mode 100644 index 000000000..e61e2c509 --- /dev/null +++ b/include/faabric/transport/PointToPointMessage.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +namespace faabric::transport { + +/* Simple fixed-size C-struct to capture the state of a PTP message moving + * through Faabric. + * + * We require fixed-size, and no unique pointers to be able to use + * high-throughput ring-buffers to send the messages around. This also means + * that we manually malloc/free the data pointer. The message size is: + * 4 * int32_t = 4 * 4 bytes = 16 bytes + * 1 * size_t = 1 * 8 bytes = 8 bytes + * 1 * void* = 1 * 8 bytes = 8 bytes + * total = 32 bytes = 4 * 8 so the struct is naturally 8 byte-aligned + */ +struct PointToPointMessage +{ + int32_t appId; + int32_t groupId; + int32_t sendIdx; + int32_t recvIdx; + size_t dataSize; + void* dataPtr; +}; +static_assert((sizeof(PointToPointMessage) % 8) == 0, + "PTP message mus be 8-aligned!"); + +// The wire format for a PTP message is very simple: the fixed-size struct, +// followed by dataSize bytes containing the payload. +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg); + +// This parsing function mallocs space for the message payload. This is to +// keep the PTP message at fixed-size, and be able to efficiently move it +// around in-memory queues +void parsePtpMsg(std::span bytes, PointToPointMessage* msg); + +// Alternative signature for parsing PTP messages for when the caller can +// provide an already-allocated buffer to write into +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer); +} diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index 5b0887c7a..9839d7f96 100644 --- a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -59,14 +60,16 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, serializeMpiMsg(serialisedBuffer, msg); try { - broker.sendMessage( - thisRankMsg->groupid(), - sendRank, - recvRank, - reinterpret_cast(serialisedBuffer.data()), - serialisedBuffer.size(), - dstHost, - true); + // It is safe to send a pointer to a stack-allocated object + // because the broker will make an additional copy (and so will NNG!) + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank, + .dataSize = serialisedBuffer.size(), + .dataPtr = (void*)serialisedBuffer.data() }); + + broker.sendMessage(msg, dstHost, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - send {} -> {}", thisRankMsg->appid(), @@ -80,10 +83,12 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) { - std::vector msg; + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank }); try { - msg = - broker.recvMessage(thisRankMsg->groupid(), sendRank, recvRank, true); + broker.recvMessage(msg, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - recv (remote) {} -> {}", thisRankMsg->appid(), @@ -96,7 +101,8 @@ MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) // TODO(mpi-opt): make sure we minimze copies here MpiMessage parsedMsg; - parseMpiMsg(msg, &parsedMsg); + std::vector msgBytes((uint8_t*) msg.dataPtr, (uint8_t*) msg.dataPtr + msg.dataSize); + parseMpiMsg(msgBytes, &parsedMsg); return parsedMsg; } diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 5daa2b5cb..8ed729a8e 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -199,14 +199,6 @@ message StateAppendedResponse { // POINT-TO-POINT // --------------------------------------------- -message PointToPointMessage { - int32 appId = 1; - int32 groupId = 2; - int32 sendIdx = 3; - int32 recvIdx = 4; - bytes data = 5; -} - message PointToPointMappings { int32 appId = 1; int32 groupId = 2; diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 63a87c49d..6ce6f4afd 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -459,12 +459,32 @@ Scheduler::checkForMigrationOpportunities(faabric::Message& msg, auto groupIdxs = broker.getIdxsRegisteredForGroup(groupId); groupIdxs.erase(0); for (const auto& recvIdx : groupIdxs) { - broker.sendMessage( - groupId, 0, recvIdx, BYTES_CONST(&newGroupId), sizeof(int)); + // It is safe to send a pointer to the stack, because the + // transport layer will perform an additional copy of the PTP + // message to put it in the message body + // TODO(no-inproc): this may not be true once we move the inproc + // sockets to in-memory queues + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, + .sendIdx = 0, + .recvIdx = recvIdx, + .dataSize = sizeof(int), + .dataPtr = &newGroupId }); + + broker.sendMessage(msg); } } else if (overwriteNewGroupId == 0) { - std::vector bytes = broker.recvMessage(groupId, 0, groupIdx); + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, .sendIdx = 0, .recvIdx = groupIdx }); + // TODO(no-order): when we remove the need to order ptp messages we + // should be able to call recv giving it a pre-allocated buffer, + // avoiding the hassle of malloc-ing and free-ing + broker.recvMessage(msg); + std::vector bytes((uint8_t*)msg.dataPtr, + (uint8_t*)msg.dataPtr + msg.dataSize); newGroupId = faabric::util::bytesToInt(bytes); + // The previous call makes a copy, so safe to free now + faabric::util::free(msg.dataPtr); } else { // In some settings, like tests, we already know the new group id, so // we can set it here (and in fact, we need to do so when faking two diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index e8b7c339d..e68fa72bd 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -9,6 +9,7 @@ faabric_lib(transport MessageEndpointServer.cpp PointToPointBroker.cpp PointToPointClient.cpp + PointToPointMessage.cpp PointToPointServer.cpp ) diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 7984c951b..4bee05763 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -36,6 +36,7 @@ void MessageEndpointClient::asyncSend(int header, sequenceNum); } +// TODO: consider making an iovec-style scatter/gather alternative signature void MessageEndpointClient::asyncSend(int header, const uint8_t* buffer, size_t bufferSize, diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 9581fc27a..d2c5a0cc3 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -53,7 +53,8 @@ thread_local std::vector sentMsgCount; thread_local std::vector recvMsgCount; -thread_local std::vector> outOfOrderMsgs; +thread_local std::vector>> + outOfOrderMsgs; static std::shared_ptr getClient(const std::string& host) { @@ -202,8 +203,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) groupId, recursive); - ptpBroker.recvMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } else { // Notify remote locker that they've acquired the lock SPDLOG_TRACE( @@ -217,10 +222,6 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) } } else { auto cli = getClient(mainHost); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote lock {}:{}:{} to {}", groupId, @@ -232,7 +233,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) // acquired cli->groupLock(appId, groupId, groupIdx, recursive); - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } } @@ -285,10 +291,6 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote unlock {}:{}:{} to {}", groupId, @@ -308,9 +310,13 @@ void PointToPointGroup::localUnlock() void PointToPointGroup::notifyLocked(int groupIdx) { std::vector data(1, 0); - - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg = { .appId = 0, + .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }; + ptpBroker.sendMessage(msg); } void PointToPointGroup::barrier(int groupIdx) @@ -324,23 +330,40 @@ void PointToPointGroup::barrier(int groupIdx) if (groupIdx == POINT_TO_POINT_MAIN_IDX) { // Receive from all for (int i = 1; i < groupSize; i++) { - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } // Reply to all std::vector data(1, 0); for (int i = 1; i < groupSize; i++) { - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, i, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = i, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } else { // Do the send - std::vector data(1, 0); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); // Await the response - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage response({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(response); } } @@ -351,15 +374,23 @@ void PointToPointGroup::notify(int groupIdx) SPDLOG_TRACE( "Master group {} waiting for notify from index {}", groupId, i); - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); SPDLOG_TRACE("Master group {} notified by index {}", groupId, i); } } else { - std::vector data(1, 0); SPDLOG_TRACE("Notifying group {} from index {}", groupId, groupIdx); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } @@ -581,22 +612,11 @@ void PointToPointBroker::updateHostForIdx(int groupId, mappings[key] = newHost; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, std::string hostHint, bool mustOrderMsg) { - sendMessage(groupId, - sendIdx, - recvIdx, - buffer, - bufferSize, - mustOrderMsg, - NO_SEQUENCE_NUM, - hostHint); + sendMessage(msg, mustOrderMsg, NO_SEQUENCE_NUM, hostHint); } // Gets or creates a pair of inproc endpoints (recv&send) in the endpoints map. @@ -634,11 +654,7 @@ auto getEndpointPtrs(const std::string& label) return endpointPtrs; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, bool mustOrderMsg, int sequenceNum, std::string hostHint) @@ -647,19 +663,21 @@ void PointToPointBroker::sendMessage(int groupId, // sender thread, and another time from the point-to-point server to route // it to the receiver thread - waitForMappingsOnThisHost(groupId); + waitForMappingsOnThisHost(msg.groupId); // If the application code knows which host does the receiver live in // (cached for performance) we allow it to provide a hint to avoid // acquiring a shared lock here - std::string host = - hostHint.empty() ? getHostForReceiver(groupId, recvIdx) : hostHint; + std::string host = hostHint.empty() + ? getHostForReceiver(msg.groupId, msg.recvIdx) + : hostHint; // Set the sequence number if we need ordering and one is not provided bool mustSetSequenceNum = mustOrderMsg && sequenceNum == NO_SEQUENCE_NUM; if (host == conf.endpointHost) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = @@ -671,46 +689,49 @@ void PointToPointBroker::sendMessage(int groupId, // the sender thread we add a sequence number (if needed) int localSendSeqNum = sequenceNum; if (mustSetSequenceNum) { - localSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + localSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Local point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); try { - endpoint.send(NO_HEADER, buffer, bufferSize, localSendSeqNum); + // TODO(no-inproc): once we convert the inproc endpoints to a queue + // we should be able to just push the whole message to the queue + std::vector buffer(sizeof(PointToPointMessage) + + msg.dataSize); + serializePtpMsg(buffer, msg); + endpoint.send( + NO_HEADER, buffer.data(), buffer.size(), localSendSeqNum); } catch (std::runtime_error& e) { SPDLOG_ERROR("Timed-out with local point-to-point message {}:{}:{} " "(seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); throw e; } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(sendIdx); - msg.set_recvidx(recvIdx); - msg.set_data(buffer, bufferSize); // When sending a remote message, we set a sequence number if required int remoteSendSeqNum = NO_SEQUENCE_NUM; if (mustSetSequenceNum) { - remoteSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + remoteSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Remote point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); @@ -719,59 +740,81 @@ void PointToPointBroker::sendMessage(int groupId, } catch (std::runtime_error& e) { SPDLOG_TRACE("Timed-out with remote point-to-point message " "{}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); } } } -Message PointToPointBroker::doRecvMessage(int groupId, int sendIdx, int recvIdx) +std::pair PointToPointBroker::doRecvMessage( + PointToPointMessage& msg) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = *std::get>( *endpointPtrs); - return endpoint.recv(); + // TODO(no-inproc): this will become a pop from a queue, not a read from + // an in-proc socket + Message bytes = endpoint.recv(); + + // WARNING: this call mallocs + parsePtpMsg(bytes.udata(), &msg); + + /* TODO(no-order): for the moment always parse and malloc memory, as it is + * not easy to track when did we malloc or not. This is gonna become + * simpler once we remove the need to order messages in the PTP layer + * + if (hasPreAllocBuffer) { + std::span msgDataSpan((uint8_t*) msg.dataPtr, msg.dataSize); + parsePtpMsg(bytes.udata(), &msg, msgDataSpan); + } else { + parsePtpMsg(bytes.udata(), &msg); + } + */ + + assert(getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx) == label); + + return std::make_pair(bytes.getResponseCode(), + bytes.getSequenceNum()); } -std::vector PointToPointBroker::recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg) +void PointToPointBroker::recvMessage(PointToPointMessage& msg, + bool mustOrderMsg) { // If we don't need to receive messages in order, return here if (!mustOrderMsg) { - // TODO - can we avoid this copy? - return doRecvMessage(groupId, sendIdx, recvIdx).dataCopy(); + doRecvMessage(msg); + return; } // Get the sequence number we expect to receive - int expectedSeqNum = getExpectedSeqNum(groupId, sendIdx); + int expectedSeqNum = getExpectedSeqNum(msg.groupId, msg.sendIdx); // We first check if we have already received the message. We only need to // check this once. - auto foundIterator = - std::find_if(outOfOrderMsgs.at(sendIdx).begin(), - outOfOrderMsgs.at(sendIdx).end(), - [expectedSeqNum](const Message& msg) { - return msg.getSequenceNum() == expectedSeqNum; - }); - if (foundIterator != outOfOrderMsgs.at(sendIdx).end()) { + auto foundIterator = std::find_if( + outOfOrderMsgs.at(msg.sendIdx).begin(), + outOfOrderMsgs.at(msg.sendIdx).end(), + [expectedSeqNum](const std::pair& pair) { + return pair.first == expectedSeqNum; + }); + if (foundIterator != outOfOrderMsgs.at(msg.sendIdx).end()) { SPDLOG_TRACE("Retrieved the expected message ({}:{} seq: {}) from the " "out-of-order buffer", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - Message returnMsg = std::move(*foundIterator); - outOfOrderMsgs.at(sendIdx).erase(foundIterator); - return returnMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + msg = foundIterator->second; + outOfOrderMsgs.at(msg.sendIdx).erase(foundIterator); + return; } // Given that we don't have the message, we query the transport layer until @@ -779,47 +822,52 @@ std::vector PointToPointBroker::recvMessage(int groupId, while (true) { SPDLOG_TRACE( "Entering loop to query transport layer for msg ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - // Receive from the transport layer - Message recvMsg = doRecvMessage(groupId, sendIdx, recvIdx); + + // Receive from the transport layer with the same group id and + // send/recv indexes + PointToPointMessage tmpMsg({ .groupId = msg.groupId, + .sendIdx = msg.sendIdx, + .recvIdx = msg.recvIdx }); + auto [responseCode, seqNum] = doRecvMessage(tmpMsg); // If the receive was not successful, exit the loop - if (recvMsg.getResponseCode() != - faabric::transport::MessageResponseCode::SUCCESS) { + if (responseCode != faabric::transport::MessageResponseCode::SUCCESS) { SPDLOG_WARN( "Error {} ({}) when awaiting a message ({}:{} seq: {} label: {})", - static_cast(recvMsg.getResponseCode()), - MessageResponseCodeText.at(recvMsg.getResponseCode()), - sendIdx, - recvIdx, + static_cast(responseCode), + MessageResponseCodeText.at(responseCode), + msg.sendIdx, + msg.recvIdx, expectedSeqNum, - getPointToPointKey(groupId, sendIdx, recvIdx)); + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx)); throw std::runtime_error("Error when awaiting a PTP message"); } // If the sequence numbers match, exit the loop - int seqNum = recvMsg.getSequenceNum(); if (seqNum == expectedSeqNum) { SPDLOG_TRACE("Received the expected message ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - return recvMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + + msg = tmpMsg; + return; } // If not, we must insert the received message in the out of order // received messages SPDLOG_TRACE("Received out-of-order message ({}:{} seq: {}) (expected: " "{} - got: {})", - sendIdx, - recvIdx, + tmpMsg.sendIdx, + tmpMsg.recvIdx, seqNum, expectedSeqNum, seqNum); - outOfOrderMsgs.at(sendIdx).emplace_back(std::move(recvMsg)); + outOfOrderMsgs.at(tmpMsg.sendIdx).emplace_back(seqNum, tmpMsg); } } @@ -874,10 +922,10 @@ void PointToPointBroker::resetThreadLocalCache() void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) { + /* int postMigrationOkCode = 1337; int recvCode = 0; - // TODO: implement this as a broadcast in the PTP broker int mainIdx = 0; if (groupIdx == mainIdx) { auto groupIdxs = getIdxsRegisteredForGroup(groupId); @@ -902,6 +950,8 @@ void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) recvCode); throw std::runtime_error("Error in post-migration hook"); } + */ + PointToPointGroup::getGroup(groupId)->barrier(groupIdx); SPDLOG_DEBUG("{}:{} exiting post-migration hook", groupId, groupIdx); } diff --git a/src/transport/PointToPointClient.cpp b/src/transport/PointToPointClient.cpp index d0b7188f8..506fc9874 100644 --- a/src/transport/PointToPointClient.cpp +++ b/src/transport/PointToPointClient.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -13,12 +14,11 @@ static std::mutex mockMutex; static std::vector> sentMappings; -static std::vector> - sentMessages; +static std::vector> sentMessages; static std::vector> + PointToPointMessage>> sentLockMessages; std::vector> @@ -27,7 +27,7 @@ getSentMappings() return sentMappings; } -std::vector> +std::vector> getSentPointToPointMessages() { return sentMessages; @@ -35,7 +35,7 @@ getSentPointToPointMessages() std::vector> + PointToPointMessage>> getSentLockMessages() { return sentLockMessages; @@ -64,13 +64,18 @@ void PointToPointClient::sendMappings(faabric::PointToPointMappings& mappings) } } -void PointToPointClient::sendMessage(faabric::PointToPointMessage& msg, +void PointToPointClient::sendMessage(const PointToPointMessage& msg, int sequenceNum) { if (faabric::util::isMockMode()) { sentMessages.emplace_back(host, msg); } else { - asyncSend(PointToPointCall::MESSAGE, &msg, sequenceNum); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(msg) + msg.dataSize); + serializePtpMsg(buffer, msg); + asyncSend( + PointToPointCall::MESSAGE, buffer.data(), buffer.size(), sequenceNum); } } @@ -80,11 +85,12 @@ void PointToPointClient::makeCoordinationRequest( int groupIdx, faabric::transport::PointToPointCall call) { - faabric::PointToPointMessage req; - req.set_appid(appId); - req.set_groupid(groupId); - req.set_sendidx(groupIdx); - req.set_recvidx(POINT_TO_POINT_MAIN_IDX); + PointToPointMessage req({ .appId = appId, + .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); switch (call) { case (faabric::transport::PointToPointCall::LOCK_GROUP): { @@ -115,7 +121,11 @@ void PointToPointClient::makeCoordinationRequest( faabric::util::UniqueLock lock(mockMutex); sentLockMessages.emplace_back(host, call, req); } else { - asyncSend(call, &req); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(PointToPointMessage) + req.dataSize); + serializePtpMsg(buffer, req); + asyncSend(call, buffer.data(), buffer.size()); } } diff --git a/src/transport/PointToPointMessage.cpp b/src/transport/PointToPointMessage.cpp new file mode 100644 index 000000000..e9415db11 --- /dev/null +++ b/src/transport/PointToPointMessage.cpp @@ -0,0 +1,62 @@ +#include +#include + +#include +#include +#include + +namespace faabric::transport { + +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg) +{ + assert(buffer.size() == sizeof(PointToPointMessage) + msg.dataSize); + std::memcpy(buffer.data(), &msg, sizeof(PointToPointMessage)); + + if (msg.dataSize > 0 && msg.dataPtr != nullptr) { + std::memcpy(buffer.data() + sizeof(PointToPointMessage), + msg.dataPtr, + msg.dataSize); + } +} + +// Parse all the fixed-size parts of the struct +static void parsePtpMsgCommon(std::span bytes, + PointToPointMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(PointToPointMessage)); + std::memcpy(msg, bytes.data(), sizeof(PointToPointMessage)); + size_t thisDataSize = bytes.size() - sizeof(PointToPointMessage); + assert(thisDataSize == msg->dataSize); + + if (thisDataSize == 0) { + msg->dataPtr = nullptr; + } +} + +void parsePtpMsg(std::span bytes, PointToPointMessage* msg) +{ + parsePtpMsgCommon(bytes, msg); + + if (msg->dataSize == 0) { + return; + } + + // malloc memory for the PTP message payload + msg->dataPtr = faabric::util::malloc(msg->dataSize); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} + +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer) +{ + parsePtpMsgCommon(bytes, msg); + + assert(msg->dataSize == preAllocBuffer.size()); + msg->dataPtr = preAllocBuffer.data(); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} +} diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp index 173fec0bf..6224eed84 100644 --- a/src/transport/PointToPointServer.cpp +++ b/src/transport/PointToPointServer.cpp @@ -1,12 +1,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace faabric::transport { @@ -25,9 +27,11 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) int sequenceNum = message.getSequenceNum(); switch (header) { case (faabric::transport::PointToPointCall::MESSAGE): { - PARSE_MSG(faabric::PointToPointMessage, - message.udata().data(), - message.udata().size()) + // Here we are copying the message from the transport layer (NNG) + // into our PTP message structure + // NOTE: this mallocs + PointToPointMessage parsedMsg; + parsePtpMsg(message.udata(), &parsedMsg); // If the sequence number is set, we must also set the ordering // flag @@ -35,13 +39,15 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) // Send the message locally to the downstream socket, add the // sequence number for in-order reception - broker.sendMessage(parsedMsg.groupid(), - parsedMsg.sendidx(), - parsedMsg.recvidx(), - BYTES_CONST(parsedMsg.data().c_str()), - parsedMsg.data().size(), - mustOrderMsg, - sequenceNum); + broker.sendMessage(parsedMsg, mustOrderMsg, sequenceNum); + + // TODO(no-inproc): for the moment, the downstream (inproc) + // socket makes a copy of this message, so we can free it now + // after sending. This will not be the case once we move to + // in-memory queues + if (parsedMsg.dataPtr != nullptr) { + faabric::util::free(parsedMsg.dataPtr); + } break; } case faabric::transport::PointToPointCall::LOCK_GROUP: { @@ -101,28 +107,33 @@ std::unique_ptr PointToPointServer::doRecvMappings( void PointToPointServer::recvGroupLock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); + SPDLOG_TRACE("Receiving lock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->lock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->lock(parsedMsg.sendIdx, recursive); } void PointToPointServer::recvGroupUnlock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); SPDLOG_TRACE("Receiving unlock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->unlock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->unlock(parsedMsg.sendIdx, recursive); } void PointToPointServer::onWorkerStop() diff --git a/tests/dist/transport/functions.cpp b/tests/dist/transport/functions.cpp index 8c99b05b2..1485f5f47 100644 --- a/tests/dist/transport/functions.cpp +++ b/tests/dist/transport/functions.cpp @@ -4,9 +4,9 @@ #include "faabric_utils.h" #include "init.h" -#include #include #include +#include #include #include #include @@ -43,12 +43,25 @@ int handlePointToPointFunction( std::vector expectedRecvData(10, recvFromIdx); // Do the sending - broker.sendMessage( - groupId, groupIdx, sendToIdx, sendData.data(), sendData.size()); + PointToPointMessage sendMsg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = sendToIdx, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(sendMsg); // Do the receiving - std::vector actualRecvData = - broker.recvMessage(groupId, recvFromIdx, groupIdx); + PointToPointMessage recvMsg({ .groupId = groupId, + .sendIdx = recvFromIdx, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.recvMessage(recvMsg); + std::vector actualRecvData(recvMsg.dataSize); + std::memcpy(actualRecvData.data(), recvMsg.dataPtr, recvMsg.dataSize); + // TODO(no-order): we will be able to change the signature of recvMessage + // to take in a pre-allocated buffer to read into + faabric::util::free(recvMsg.dataPtr); // Check data is as expected if (actualRecvData != expectedRecvData) { @@ -82,19 +95,31 @@ int handleManyPointToPointMsgFunction( // Send loop for (int i = 0; i < numMsg; i++) { std::vector sendData(5, i); - broker.sendMessage(groupId, - sendIdx, - recvIdx, - sendData.data(), - sendData.size(), - true); + PointToPointMessage sendMsg({ .groupId = groupId, + .sendIdx = sendIdx, + .recvIdx = recvIdx, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(sendMsg, true); } } else if (groupIdx == recvIdx) { // Recv loop for (int i = 0; i < numMsg; i++) { std::vector expectedData(5, i); - auto actualData = - broker.recvMessage(groupId, sendIdx, recvIdx, true); + + PointToPointMessage recvMsg({ .groupId = groupId, + .sendIdx = sendIdx, + .recvIdx = recvIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.recvMessage(recvMsg, true); + + std::vector actualData(recvMsg.dataSize); + std::memcpy(actualData.data(), recvMsg.dataPtr, recvMsg.dataSize); + // TODO(no-order): we will be able to change the signature of + // recvMessage to take in a pre-allocated buffer to read into + faabric::util::free(recvMsg.dataPtr); + if (actualData != expectedData) { SPDLOG_ERROR( "Out-of-order message reception (got: {}, expected: {})", diff --git a/tests/dist/transport/test_point_to_point.cpp b/tests/dist/transport/test_point_to_point.cpp index 28ab5b6c7..8cd0f6e49 100644 --- a/tests/dist/transport/test_point_to_point.cpp +++ b/tests/dist/transport/test_point_to_point.cpp @@ -5,7 +5,6 @@ #include "init.h" #include -#include #include #include #include diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp index 98b16b9f7..6a23b90c5 100644 --- a/tests/test/transport/test_point_to_point.cpp +++ b/tests/test/transport/test_point_to_point.cpp @@ -120,9 +120,7 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, std::vector sentDataA = { 0, 1, 2, 3 }; std::vector receivedDataA; std::vector sentDataB = { 3, 4, 5 }; - std::vector receivedDataB; std::vector sentDataC = { 6, 7, 8 }; - std::vector receivedDataC; std::shared_ptr msgLatch = std::make_shared(2, 1000); @@ -131,34 +129,60 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, PointToPointBroker& broker = getPointToPointBroker(); // Receive the first message - receivedDataA = broker.recvMessage(groupId, idxA, idxB); + PointToPointMessage msgAB( + { .groupId = groupId, .sendIdx = idxA, .recvIdx = idxB }); + broker.recvMessage(msgAB); + receivedDataA.resize(msgAB.dataSize); + std::memcpy(receivedDataA.data(), msgAB.dataPtr, msgAB.dataSize); + faabric::util::free(msgAB.dataPtr); msgLatch->wait(); // Send a message back - broker.sendMessage( - groupId, idxB, idxA, sentDataB.data(), sentDataB.size()); + PointToPointMessage msgBA({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sentDataB.size(), + .dataPtr = sentDataB.data() }); + broker.sendMessage(msgBA); // Lastly, send another message specifying the recepient host to avoid // an extra check in the broker - broker.sendMessage(groupId, - idxB, - idxA, - sentDataC.data(), - sentDataC.size(), - std::string(LOCALHOST)); + PointToPointMessage msgBA2({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sentDataC.size(), + .dataPtr = sentDataC.data() }); + broker.sendMessage(msgBA2, std::string(LOCALHOST)); broker.resetThreadLocalCache(); }); // Only send the message after the thread creates a receiving socket to // avoid deadlock - broker.sendMessage(groupId, idxA, idxB, sentDataA.data(), sentDataA.size()); + PointToPointMessage msgAB({ .groupId = groupId, + .sendIdx = idxA, + .recvIdx = idxB, + .dataSize = sentDataA.size(), + .dataPtr = sentDataA.data() }); + broker.sendMessage(msgAB); // Wait for the thread to handle the message msgLatch->wait(); // Receive the two messages sent back - receivedDataB = broker.recvMessage(groupId, idxB, idxA); - receivedDataC = broker.recvMessage(groupId, idxB, idxA); + + PointToPointMessage msgBA1( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msgBA1); + std::vector receivedDataB( + (uint8_t*)msgBA1.dataPtr, (uint8_t*)msgBA1.dataPtr + msgBA1.dataSize); + faabric::util::free(msgBA1.dataPtr); + + PointToPointMessage msgBA2( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msgBA2); + std::vector receivedDataC( + (uint8_t*)msgBA2.dataPtr, (uint8_t*)msgBA2.dataPtr + msgBA2.dataSize); + faabric::util::free(msgBA2.dataPtr); if (t.joinable()) { t.join(); @@ -236,22 +260,28 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, std::vector recvData; for (int i = 0; i < numMsg; i++) { - recvData = - broker.recvMessage(groupId, idxA, idxB, isMessageOrderingOn); sendData = std::vector(3, i); + PointToPointMessage msg( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msg, isMessageOrderingOn); + recvData.resize(msg.dataSize); + // TODO(no-order): when we remove the need to order PTP messages + // we will be able to provide a buffer to receive the message into + std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize); REQUIRE(recvData == sendData); + faabric::util::free(msg.dataPtr); } msgLatch->wait(); for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - broker.sendMessage(groupId, - idxB, - idxA, - sendData.data(), - sendData.size(), - isMessageOrderingOn); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(msg, isMessageOrderingOn); } broker.resetThreadLocalCache(); @@ -262,20 +292,26 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - broker.sendMessage(groupId, - idxA, - idxB, - sendData.data(), - sendData.size(), - isMessageOrderingOn); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(msg, isMessageOrderingOn); } msgLatch->wait(); for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - recvData = broker.recvMessage(groupId, idxB, idxA, isMessageOrderingOn); + PointToPointMessage msg( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msg, isMessageOrderingOn); + recvData.resize(msg.dataSize); + // REQUIRE(msg.dataSize == recvData.size()); + std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize); REQUIRE(sendData == recvData); + faabric::util::free(msg.dataPtr); } if (t.joinable()) { diff --git a/tests/test/transport/test_point_to_point_groups.cpp b/tests/test/transport/test_point_to_point_groups.cpp index 8d9761335..fa583e70b 100644 --- a/tests/test/transport/test_point_to_point_groups.cpp +++ b/tests/test/transport/test_point_to_point_groups.cpp @@ -135,8 +135,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture, op = PointToPointCall::LOCK_GROUP; // Prepare response - broker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.sendMessage(msg); group->lock(groupIdx, false); } @@ -147,8 +151,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture, recursive = true; // Prepare response - broker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.sendMessage(msg); group->lock(groupIdx, recursive); } @@ -166,8 +174,7 @@ TEST_CASE_METHOD(PointToPointGroupFixture, group->unlock(groupIdx, recursive); } - std::vector< - std::tuple> + std::vector> actualRequests = getSentLockMessages(); REQUIRE(actualRequests.size() == 1); @@ -176,11 +183,11 @@ TEST_CASE_METHOD(PointToPointGroupFixture, PointToPointCall actualOp = std::get<1>(actualRequests.at(0)); REQUIRE(actualOp == op); - faabric::PointToPointMessage req = std::get<2>(actualRequests.at(0)); - REQUIRE(req.appid() == appId); - REQUIRE(req.groupid() == groupId); - REQUIRE(req.sendidx() == groupIdx); - REQUIRE(req.recvidx() == POINT_TO_POINT_MAIN_IDX); + PointToPointMessage req = std::get<2>(actualRequests.at(0)); + REQUIRE(req.appId == appId); + REQUIRE(req.groupId == groupId); + REQUIRE(req.sendIdx == groupIdx); + REQUIRE(req.recvIdx == POINT_TO_POINT_MAIN_IDX); } TEST_CASE_METHOD(PointToPointGroupFixture, diff --git a/tests/test/transport/test_point_to_point_message.cpp b/tests/test/transport/test_point_to_point_message.cpp new file mode 100644 index 000000000..51b1cb87c --- /dev/null +++ b/tests/test/transport/test_point_to_point_message.cpp @@ -0,0 +1,95 @@ +#include + +#include +#include + +#include + +using namespace faabric::transport; + +namespace tests { + +bool arePtpMsgEqual(const PointToPointMessage& msgA, const PointToPointMessage& msgB) +{ + // First, compare the message body (excluding the pointer, which we + // know is at the end) + if (std::memcmp(&msgA, &msgB, sizeof(PointToPointMessage) - sizeof(void*)) != 0) { + return false; + } + + // Check that if one buffer points to null, so must do the other + if (msgA.dataPtr == nullptr || msgB.dataPtr == nullptr) { + return msgA.dataPtr == msgB.dataPtr; + } + + return std::memcmp(msgA.dataPtr, msgB.dataPtr, msgA.dataSize) == 0; +} + +TEST_CASE("Test (de)serialising a PTP message", "[ptp]") +{ + PointToPointMessage msg({ .appId = 1, + .groupId = 2, + .sendIdx = 3, + .recvIdx = 4, + .dataSize = 0, + .dataPtr = nullptr }); + + SECTION("Empty message") + { + msg.dataSize = 0; + msg.dataPtr = nullptr; + } + + SECTION("Non-empty message") + { + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.dataSize = nums.size() * sizeof(int); + msg.dataPtr = faabric::util::malloc(msg.dataSize); + std::memcpy(msg.dataPtr, nums.data(), msg.dataSize); + } + + // Serialise and de-serialise + std::vector buffer(sizeof(PointToPointMessage) + msg.dataSize); + serializePtpMsg(buffer, msg); + + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + + REQUIRE(arePtpMsgEqual(msg, parsedMsg)); + + if (msg.dataPtr != nullptr) { + faabric::util::free(msg.dataPtr); + } + if (parsedMsg.dataPtr != nullptr) { + faabric::util::free(parsedMsg.dataPtr); + } +} + +TEST_CASE("Test (de)serialising a PTP message into prealloc buffer", "[ptp]") +{ + PointToPointMessage msg({ .appId = 1, + .groupId = 2, + .sendIdx = 3, + .recvIdx = 4, + .dataSize = 0, + .dataPtr = nullptr }); + + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.dataSize = nums.size() * sizeof(int); + msg.dataPtr = faabric::util::malloc(msg.dataSize); + std::memcpy(msg.dataPtr, nums.data(), msg.dataSize); + + // Serialise and de-serialise + std::vector buffer(sizeof(PointToPointMessage) + msg.dataSize); + serializePtpMsg(buffer, msg); + + std::vector preAllocBuffer(msg.dataSize); + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg, preAllocBuffer); + + REQUIRE(arePtpMsgEqual(msg, parsedMsg)); + REQUIRE(parsedMsg.dataPtr == preAllocBuffer.data()); + + faabric::util::free(msg.dataPtr); +} +}