diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h
index f24f37d7f..95dabba1f 100644
--- a/include/faabric/mpi/MpiMessage.h
+++ b/include/faabric/mpi/MpiMessage.h
@@ -3,6 +3,11 @@
 #include <cstddef>
 #include <cstdint>
 
+// Constant copied from OpenMPI's SM implementation. It indicates the maximum
+// number of bytes that we may inline in a message (rather than malloc-ing)
+// https://github.com/open-mpi/ompi/blob/main/opal/mca/btl/sm/btl_sm_component.c#L153
+#define MPI_MAX_INLINE_SEND 256
+
 namespace faabric::mpi {
 
 enum MpiMessageType : int32_t
@@ -49,7 +54,11 @@ struct MpiMessage
     // struct 8-aligned
     int32_t requestId;
     MpiMessageType messageType;
-    void* buffer;
+    union
+    {
+        void* buffer;
+        uint8_t inlineMsg[MPI_MAX_INLINE_SEND];
+    };
 };
 
 static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");
@@ -60,7 +69,14 @@ inline size_t payloadSize(const MpiMessage& msg)
 
 inline size_t msgSize(const MpiMessage& msg)
 {
-    return sizeof(MpiMessage) + payloadSize(msg);
+    size_t payloadSz = payloadSize(msg);
+
+    // If we can inline the message, we do not need to add anything else
+    if (payloadSz <= MPI_MAX_INLINE_SEND) {
+        return sizeof(MpiMessage);
+    }
+
+    return sizeof(MpiMessage) + payloadSz;
 }
 
 void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);
diff --git a/src/mpi/MpiMessage.cpp b/src/mpi/MpiMessage.cpp
index 57ee8c85e..dee2366a1 100644
--- a/src/mpi/MpiMessage.cpp
+++ b/src/mpi/MpiMessage.cpp
@@ -12,24 +12,25 @@ void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
     assert(msg != nullptr);
     assert(bytes.size() >= sizeof(MpiMessage));
     std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
-    size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
-    assert(thisPayloadSize == payloadSize(*msg));
+    size_t thisPayloadSize = payloadSize(*msg);
 
     if (thisPayloadSize == 0) {
         msg->buffer = nullptr;
         return;
    }
 
-    msg->buffer = faabric::util::malloc(thisPayloadSize);
-    std::memcpy(
-      msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
+    if (thisPayloadSize > MPI_MAX_INLINE_SEND) {
+        msg->buffer = faabric::util::malloc(thisPayloadSize);
+        std::memcpy(
+          msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
+    }
 }
 
 void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
 {
     std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
 
     size_t payloadSz = payloadSize(msg);
-    if (payloadSz > 0 && msg.buffer != nullptr) {
+    if (payloadSz > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
         std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
     }
 }
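At a glance, the header change is a small-buffer optimisation: any payload of up to MPI_MAX_INLINE_SEND bytes now lives inside the message struct itself (the union overlays the inline byte array with the old heap pointer), so small messages skip a malloc/free pair and add no extra bytes on the wire. Below is a minimal, self-contained sketch of the resulting size semantics; `Msg`, `MAX_INLINE` and the helpers are illustrative stand-ins, not faabric's actual types:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr size_t MAX_INLINE = 256; // stand-in for MPI_MAX_INLINE_SEND

struct Msg
{
    int32_t count;
    int32_t typeSize;
    union
    {
        void* buffer;                  // payload > MAX_INLINE: heap pointer
        uint8_t inlineMsg[MAX_INLINE]; // payload <= MAX_INLINE: in-place bytes
    };
};

size_t payloadSize(const Msg& msg)
{
    return (size_t)msg.count * msg.typeSize;
}

// Mirrors the patched msgSize: an inline-able payload adds nothing on the
// wire beyond the struct itself; a larger one appends its bytes after it
size_t msgSize(const Msg& msg)
{
    size_t payloadSz = payloadSize(msg);
    if (payloadSz <= MAX_INLINE) {
        return sizeof(Msg);
    }
    return sizeof(Msg) + payloadSz;
}

int main()
{
    Msg small = {};
    small.count = 3;
    small.typeSize = sizeof(int32_t);
    int32_t nums[3] = { 1, 2, 3 };
    std::memcpy(small.inlineMsg, nums, sizeof(nums)); // no allocation

    Msg large = {};
    large.count = 1024;
    large.typeSize = sizeof(int32_t);

    assert(msgSize(small) == sizeof(Msg));
    assert(msgSize(large) == sizeof(Msg) + 4096);
    return 0;
}
```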
diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp
index d27b259f5..323fc2809 100644
--- a/src/mpi/MpiWorld.cpp
+++ b/src/mpi/MpiWorld.cpp
@@ -580,8 +580,7 @@ void MpiWorld::send(int sendRank,
                        .recvRank = recvRank,
                        .typeSize = dataType->size,
                        .count = count,
-                       .messageType = messageType,
-                       .buffer = nullptr };
+                       .messageType = messageType };
 
     // Mock the message sending in tests
 #ifndef NDEBUG
@@ -591,17 +590,24 @@ void MpiWorld::send(int sendRank,
     }
 #endif
 
-    bool mustSendData = count > 0 && buffer != nullptr;
+    size_t dataSize = count * dataType->size;
+    bool mustSendData = dataSize > 0 && buffer != nullptr;
 
     // Dispatch the message locally or globally
     if (isLocal) {
         // Take control over the buffer data if we are gonna move it to
         // the in-memory queues for local messaging
         if (mustSendData) {
-            void* bufferPtr = faabric::util::malloc(count * dataType->size);
-            std::memcpy(bufferPtr, buffer, count * dataType->size);
+            if (dataSize <= MPI_MAX_INLINE_SEND) {
+                std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
+            } else {
+                void* bufferPtr = faabric::util::malloc(count * dataType->size);
+                std::memcpy(bufferPtr, buffer, count * dataType->size);
 
-            msg.buffer = bufferPtr;
+                msg.buffer = bufferPtr;
+            }
+        } else {
+            msg.buffer = nullptr;
         }
 
         SPDLOG_TRACE(
@@ -609,7 +615,11 @@ void MpiWorld::send(int sendRank,
         getLocalQueue(sendRank, recvRank)->enqueue(msg);
     } else {
         if (mustSendData) {
-            msg.buffer = (void*)buffer;
+            if (dataSize <= MPI_MAX_INLINE_SEND) {
+                std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
+            } else {
+                msg.buffer = (void*)buffer;
+            }
         }
 
         SPDLOG_TRACE(
@@ -691,19 +701,28 @@ void MpiWorld::doRecv(const MpiMessage& msg,
                      msg.messageType,
                      messageType);
     }
-    assert(msg.messageType == messageType);
-    assert(msg.count <= count);
+
+    assert(msg.messageType == messageType);
+    assert(msg.count <= count);
+    size_t dataSize = msg.count * dataType->size;
 
     // We must copy the data into the application-provided buffer
-    if (msg.count > 0 && msg.buffer != nullptr) {
+    if (dataSize > 0) {
         // Make sure we do not overflow the recipient buffer
         auto bytesToCopy =
           std::min(msg.count * dataType->size, count * dataType->size);
-        std::memcpy(buffer, msg.buffer, bytesToCopy);
-
-        // This buffer has been malloc-ed either as part of a local `send`
-        // or as part of a remote `parseMpiMsg`
-        faabric::util::free((void*)msg.buffer);
+        if (dataSize > MPI_MAX_INLINE_SEND) {
+            assert(msg.buffer != nullptr);
+
+            std::memcpy(buffer, msg.buffer, bytesToCopy);
+
+            // This buffer has been malloc-ed either as part of a local `send`
+            // or as part of a remote `parseMpiMsg`
+            faabric::util::free((void*)msg.buffer);
+        } else {
+            std::memcpy(buffer, msg.inlineMsg, bytesToCopy);
+        }
     }
 
     // Set status values if required
@@ -1923,21 +1942,32 @@ MpiMessage MpiWorld::recvBatchReturnLast(int sendRank,
             // Copy the request id so that it is not overwritten
             int tmpRequestId = itr->requestId;
 
-            // Copy into current slot in the list, but keep a copy to the
-            // app-provided buffer to read data into
+            // Keep a copy of the app-provided buffer to recv data into, so
+            // that it is not overwritten either. Note that, irrespective of
+            // whether the message is inlined or not, this buffer pointer
+            // always points to the app-provided recv buffer
             void* providedBuffer = itr->buffer;
             *itr = internalRecv(sendRank, recvRank, isLocal);
             itr->requestId = tmpRequestId;
 
-            if (itr->buffer != nullptr) {
+            // If we have received a non-inlined message, copy the data into
+            // the provided buffer and free the one in the queue
+            size_t dataSize = itr->count * itr->typeSize;
+            if (dataSize > MPI_MAX_INLINE_SEND) {
+                assert(itr->buffer != nullptr);
                 assert(providedBuffer != nullptr);
-                // If buffers are not null, we must have a non-zero size
-                assert((itr->count * itr->typeSize) > 0);
                 std::memcpy(
                   providedBuffer, itr->buffer, itr->count * itr->typeSize);
+                faabric::util::free(itr->buffer);
+
+                itr->buffer = providedBuffer;
+            } else if (dataSize > 0) {
+                std::memcpy(
+                  providedBuffer, itr->inlineMsg, itr->count * itr->typeSize);
+            } else {
+                itr->buffer = providedBuffer;
             }
-            itr->buffer = providedBuffer;
         }
 
         assert(itr->messageType != MpiMessageType::UNACKED_MPI_MESSAGE);
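Taken together, send() and doRecv() now implement three ownership regimes on the local path: empty messages carry a null pointer, inline-able payloads are copied by value into the queued struct, and larger payloads are malloc-ed by the sender and freed by the receiver. A self-contained sketch of that contract, where `Msg`, `sendLocal`, `recvLocal` and the bare `std::queue` are simplified stand-ins for faabric's rank-to-rank queues, not its real API:

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <queue>

constexpr size_t MAX_INLINE = 256; // stand-in for MPI_MAX_INLINE_SEND

struct Msg
{
    int32_t count;
    int32_t typeSize;
    union
    {
        void* buffer;
        uint8_t inlineMsg[MAX_INLINE];
    };
};

std::queue<Msg> localQueue;

void sendLocal(const void* data, int count, int typeSize)
{
    Msg msg = {};
    msg.count = count;
    msg.typeSize = typeSize;

    size_t dataSize = (size_t)count * typeSize;
    if (dataSize <= MAX_INLINE) {
        // Inline payload travels by value inside the struct: no allocation
        std::memcpy(msg.inlineMsg, data, dataSize);
    } else {
        // Large payload: sender mallocs a copy, receiver must free it
        msg.buffer = std::malloc(dataSize);
        std::memcpy(msg.buffer, data, dataSize);
    }

    localQueue.push(msg);
}

void recvLocal(void* appBuffer)
{
    Msg msg = localQueue.front();
    localQueue.pop();

    size_t dataSize = (size_t)msg.count * msg.typeSize;
    if (dataSize > MAX_INLINE) {
        std::memcpy(appBuffer, msg.buffer, dataSize);
        // Matching free for the send-side malloc
        std::free(msg.buffer);
    } else {
        std::memcpy(appBuffer, msg.inlineMsg, dataSize);
    }
}

int main()
{
    int32_t in[3] = { 1, 2, 3 };
    int32_t out[3] = { 0, 0, 0 };
    sendLocal(in, 3, sizeof(int32_t));
    recvLocal(out);
    return out[0] == 1 ? 0 : 1;
}
```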
diff --git a/tests/dist/mpi/examples/mpi_isendrecv.cpp b/tests/dist/mpi/examples/mpi_isendrecv.cpp
index 97ea79be4..79d0d43c6 100644
--- a/tests/dist/mpi/examples/mpi_isendrecv.cpp
+++ b/tests/dist/mpi/examples/mpi_isendrecv.cpp
@@ -41,9 +41,6 @@ int iSendRecv()
     }
     printf("Rank %i - async working properly\n", rank);
 
-    delete sendRequest;
-    delete recvRequest;
-
     MPI_Finalize();
 
     return 0;
diff --git a/tests/dist/mpi/examples/mpi_send_sync_async.cpp b/tests/dist/mpi/examples/mpi_send_sync_async.cpp
index d7f8ed7c9..8c3ac7e7b 100644
--- a/tests/dist/mpi/examples/mpi_send_sync_async.cpp
+++ b/tests/dist/mpi/examples/mpi_send_sync_async.cpp
@@ -22,7 +22,6 @@ int sendSyncAsync()
             MPI_Send(&r, 1, MPI_INT, r, 0, MPI_COMM_WORLD);
             MPI_Wait(&sendRequest, MPI_STATUS_IGNORE);
         }
-        delete sendRequest;
     } else {
         // Asynchronously receive twice from rank 0
         int recvValue1 = -1;
@@ -47,8 +46,6 @@ int sendSyncAsync()
                rank);
         return 1;
     }
-        delete recvRequest1;
-        delete recvRequest2;
     }
 
     printf("Rank %i - send sync and async working properly\n", rank);
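The deleted `delete` statements were presumably left over from an earlier faabric-specific request representation; in standard MPI, MPI_Wait (or a successful MPI_Test) releases the request object and resets the handle to MPI_REQUEST_NULL, so there is nothing to free manually. A minimal sketch of the idiom the updated tests now follow:

```cpp
#include <mpi.h>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);

    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    int value = rank;
    MPI_Request request;

    // Post a non-blocking send to ourselves
    MPI_Isend(&value, 1, MPI_INT, rank, 0, MPI_COMM_WORLD, &request);

    int recvValue = -1;
    MPI_Recv(&recvValue, 1, MPI_INT, rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

    // MPI_Wait completes and releases the request; no manual delete
    MPI_Wait(&request, MPI_STATUS_IGNORE);

    MPI_Finalize();
    return 0;
}
```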
diff --git a/tests/test/mpi/test_mpi_message.cpp b/tests/test/mpi/test_mpi_message.cpp
index c1051cb84..8a8a0f1d9 100644
--- a/tests/test/mpi/test_mpi_message.cpp
+++ b/tests/test/mpi/test_mpi_message.cpp
@@ -18,9 +18,9 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
         return false;
     }
 
-    // First, compare the message body (excluding the pointer, which we
-    // know is at the end)
-    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) {
+    // First, compare the message body (excluding the union at the end)
+    size_t unionSize = sizeof(uint8_t) * MPI_MAX_INLINE_SEND;
+    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - unionSize) != 0) {
         return false;
     }
 
@@ -35,7 +35,11 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
     // Assert, as this should pass given the previous comparisons
     assert(payloadSizeA == payloadSizeB);
 
-    return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+    if (payloadSizeA > MPI_MAX_INLINE_SEND) {
+        return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+    }
+
+    return std::memcmp(msgA.inlineMsg, msgB.inlineMsg, payloadSizeA) == 0;
 }
 
 TEST_CASE("Test getting a message size", "[mpi]")
@@ -59,11 +63,23 @@ TEST_CASE("Test getting a message size", "[mpi]")
         expectedPayloadSize = 0;
     }
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
         msg.count = nums.size();
         msg.typeSize = sizeof(int);
+        std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
+
+        expectedPayloadSize = sizeof(int) * nums.size();
+        expectedMsgSize = sizeof(MpiMessage);
+    }
+
+    SECTION("Non-empty (large) message")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        std::vector<int> nums(maxNumInts + 3, 3);
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
         msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
         std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
 
@@ -74,7 +90,7 @@ TEST_CASE("Test getting a message size", "[mpi]")
     REQUIRE(expectedMsgSize == msgSize(msg));
     REQUIRE(expectedPayloadSize == payloadSize(msg));
 
-    if (msg.buffer != nullptr) {
+    if (expectedPayloadSize > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
         faabric::util::free(msg.buffer);
     }
 }
@@ -95,11 +111,22 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
         msg.buffer = nullptr;
     }
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
         msg.count = nums.size();
         msg.typeSize = sizeof(int);
+        std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
+    }
+
+    SECTION("Non-empty (large) message")
+    {
+        // Make sure we send more ints than fit in the inline buffer
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        std::vector<int> nums(maxNumInts + 3, 3);
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        REQUIRE(payloadSize(msg) > MPI_MAX_INLINE_SEND);
         msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
         std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
     }
@@ -113,11 +140,13 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
 
     REQUIRE(areMpiMsgEqual(msg, parsedMsg));
 
-    if (msg.buffer != nullptr) {
-        faabric::util::free(msg.buffer);
-    }
-    if (parsedMsg.buffer != nullptr) {
-        faabric::util::free(parsedMsg.buffer);
+    if (msg.count * msg.typeSize > MPI_MAX_INLINE_SEND) {
+        if (msg.buffer != nullptr) {
+            faabric::util::free(msg.buffer);
+        }
+        if (parsedMsg.buffer != nullptr) {
+            faabric::util::free(parsedMsg.buffer);
+        }
     }
 }
 
diff --git a/tests/test/mpi/test_mpi_world.cpp b/tests/test/mpi/test_mpi_world.cpp
index b32aab407..6aa87c5b7 100644
--- a/tests/test/mpi/test_mpi_world.cpp
+++ b/tests/test/mpi/test_mpi_world.cpp
@@ -239,11 +239,17 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send and recv on same host", "[mpi]")
     int rankA2 = 1;
     std::vector<int> messageData;
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         messageData = { 0, 1, 2 };
     }
 
+    SECTION("Non-empty (large) message")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        messageData = std::vector<int>(maxNumInts + 3, 3);
+    }
+
     SECTION("Empty message")
     {
         messageData = {};
@@ -273,8 +279,27 @@ TEST_CASE_METHOD(MpiTestFixture, "Test sendrecv", "[mpi]")
     int rankA = 1;
     int rankB = 2;
     MPI_Status status{};
-    std::vector<int> messageDataAB = { 0, 1, 2 };
-    std::vector<int> messageDataBA = { 3, 2, 1, 0 };
+    std::vector<int> messageDataAB;
+    std::vector<int> messageDataBA;
+
+    SECTION("Empty messages")
+    {
+        messageDataAB = {};
+        messageDataBA = {};
+    }
+
+    SECTION("Small messages")
+    {
+        messageDataAB = { 0, 1, 2 };
+        messageDataBA = { 3, 2, 1, 0 };
+    }
+
+    SECTION("Large messages")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        messageDataAB = std::vector<int>(maxNumInts + 3, 3);
+        messageDataBA = std::vector<int>(maxNumInts + 4, 4);
+    }
 
     // Results
     std::vector<int> recvBufferA(messageDataBA.size(), 0);
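The new test sections all sit strictly below or strictly above the threshold. The boundary case, a payload of exactly MPI_MAX_INLINE_SEND bytes, is where an inline/heap split is easiest to get wrong, so a hypothetical extra section along these lines (for the same-host send/recv test above) may be worth adding:

```cpp
// Hypothetical extra SECTION for "Test send and recv on same host":
// exercises a payload of exactly MPI_MAX_INLINE_SEND bytes, the largest
// message that still fits in the inline buffer
SECTION("Non-empty (boundary) message")
{
    int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
    messageData = std::vector<int>(maxNumInts, 3);
}
```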