mpi: inline small messages #409

Open · wants to merge 1 commit into base: main
20 changes: 18 additions & 2 deletions include/faabric/mpi/MpiMessage.h
@@ -3,6 +3,11 @@
#include <cstdint>
#include <vector>

// Constant copied from OpenMPI's SM implementation. It indicates the maximum
// number of bytes that we may inline in a message (rather than malloc-ing a
// separate buffer):
// https://github.com/open-mpi/ompi/blob/main/opal/mca/btl/sm/btl_sm_component.c#L153
#define MPI_MAX_INLINE_SEND 256

namespace faabric::mpi {

enum MpiMessageType : int32_t
@@ -49,7 +54,11 @@ struct MpiMessage
// struct 8-aligned
int32_t requestId;
MpiMessageType messageType;
void* buffer;
union
{
void* buffer;
uint8_t inlineMsg[MPI_MAX_INLINE_SEND];
};
};
static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");

@@ -60,7 +69,14 @@ inline size_t payloadSize(const MpiMessage& msg)

inline size_t msgSize(const MpiMessage& msg)
{
return sizeof(MpiMessage) + payloadSize(msg);
size_t payloadSz = payloadSize(msg);

// If the payload is inlined, it already lives inside the struct, so we do
// not need to add anything else
if (payloadSz <= MPI_MAX_INLINE_SEND) {
return sizeof(MpiMessage);
}

return sizeof(MpiMessage) + payloadSz;
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);
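To make the new size accounting concrete, here is a minimal standalone sketch (not code from this PR; `Msg` is an abridged stand-in for `MpiMessage`, and the field values are hypothetical) showing why a payload of up to MPI_MAX_INLINE_SEND bytes adds nothing to the serialized size:

#include <cstddef>
#include <cstdint>
#include <cstdio>

#define MPI_MAX_INLINE_SEND 256

// Abridged stand-in for MpiMessage: same union trick, fewer fields
struct Msg
{
    int32_t count;
    int32_t typeSize;
    union
    {
        void* buffer;                           // heap payload (> 256 bytes)
        uint8_t inlineMsg[MPI_MAX_INLINE_SEND]; // inlined payload (<= 256 bytes)
    };
};

std::size_t msgSize(const Msg& m)
{
    std::size_t payload = (std::size_t)m.count * m.typeSize;

    // An inlined payload already lives inside the struct itself
    return payload <= MPI_MAX_INLINE_SEND ? sizeof(Msg) : sizeof(Msg) + payload;
}

int main()
{
    Msg small{};
    small.count = 8; // 8 * 4 = 32 bytes, inlined
    small.typeSize = sizeof(int32_t);

    Msg large{};
    large.count = 100; // 100 * 4 = 400 bytes, heap-backed
    large.typeSize = sizeof(int32_t);

    std::printf("small: %zu, large: %zu\n", msgSize(small), msgSize(large));

    return 0;
}

The trade-off is the usual one for inline storage: every in-memory message now carries the 256-byte union instead of an 8-byte pointer, but small sends avoid a malloc/free pair and an extra copy.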
13 changes: 7 additions & 6 deletions src/mpi/MpiMessage.cpp
@@ -12,24 +12,25 @@ void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
assert(msg != nullptr);
assert(bytes.size() >= sizeof(MpiMessage));
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
size_t thisPayloadSize = payloadSize(*msg);

if (thisPayloadSize == 0) {
msg->buffer = nullptr;
return;
}

msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
if (thisPayloadSize > MPI_MAX_INLINE_SEND) {
msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
}
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
{
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
size_t payloadSz = payloadSize(msg);
if (payloadSz > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
}
}
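As a usage sketch of the (de)serialization pair above (assuming only the declarations shown from MpiMessage.h; the field setup is abridged to the size-relevant members), a round trip for an inlineable payload never touches the heap:

#include <faabric/mpi/MpiMessage.h>

#include <cstring>
#include <vector>

using namespace faabric::mpi;

int main()
{
    MpiMessage msg{};
    int32_t nums[4] = { 1, 2, 3, 4 };
    msg.count = 4;
    msg.typeSize = sizeof(int32_t);
    std::memcpy(msg.inlineMsg, nums, sizeof(nums)); // 16 bytes <= 256

    // For inlined payloads msgSize() is just sizeof(MpiMessage), so the
    // wire buffer is the struct itself, payload included
    std::vector<uint8_t> wire(msgSize(msg));
    serializeMpiMsg(wire, msg);

    MpiMessage parsed{};
    parseMpiMsg(wire, &parsed);

    // parsed.inlineMsg holds the four ints; neither side called malloc
    return std::memcmp(parsed.inlineMsg, nums, sizeof(nums)) == 0 ? 0 : 1;
}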
70 changes: 50 additions & 20 deletions src/mpi/MpiWorld.cpp
@@ -580,8 +580,7 @@ void MpiWorld::send(int sendRank,
.recvRank = recvRank,
.typeSize = dataType->size,
.count = count,
.messageType = messageType };

// Mock the message sending in tests
#ifndef NDEBUG
@@ -591,25 +590,36 @@
}
#endif

size_t dataSize = count * dataType->size;
bool mustSendData = dataSize > 0 && buffer != nullptr;

// Dispatch the message locally or globally
if (isLocal) {
// Take ownership of the buffer data if we are going to move it to
// the in-memory queues for local messaging
if (mustSendData) {
if (dataSize <= MPI_MAX_INLINE_SEND) {
std::memcpy(msg.inlineMsg, buffer, dataSize);
} else {
void* bufferPtr = faabric::util::malloc(dataSize);
std::memcpy(bufferPtr, buffer, dataSize);

msg.buffer = bufferPtr;
}
} else {
msg.buffer = nullptr;
}

SPDLOG_TRACE(
"MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
getLocalQueue(sendRank, recvRank)->enqueue(msg);
} else {
if (mustSendData) {
if (dataSize <= MPI_MAX_INLINE_SEND) {
std::memcpy(msg.inlineMsg, buffer, dataSize);
} else {
msg.buffer = (void*)buffer;
}
}

SPDLOG_TRACE(
@@ -691,19 +701,28 @@ void MpiWorld::doRecv(const MpiMessage& msg,
msg.messageType,
messageType);
}
assert(msg.messageType == messageType);
assert(msg.count <= count);

size_t dataSize = msg.count * dataType->size;

// We must copy the data into the application-provided buffer
if (dataSize > 0) {
// Make sure we do not overflow the recipient buffer
auto bytesToCopy =
std::min<size_t>(dataSize, count * dataType->size);

if (dataSize > MPI_MAX_INLINE_SEND) {
assert(msg.buffer != nullptr);

std::memcpy(buffer, msg.buffer, bytesToCopy);

// This buffer has been malloc-ed either as part of a local `send`
// or as part of a remote `parseMpiMsg`
faabric::util::free((void*)msg.buffer);
} else {
std::memcpy(buffer, msg.inlineMsg, bytesToCopy);
}
}

// Set status values if required
@@ -1923,21 +1942,32 @@ MpiMessage MpiWorld::recvBatchReturnLast(int sendRank,
// Copy the request id so that it is not overwritten
int tmpRequestId = itr->requestId;

// Copy the app-provided buffer to recv data into so that it is
// not overwritten too. Note that, irrespective of whether the
// message is inlined or not, we always set the buffer pointer to
// point to the app-provided recv-buffer
void* providedBuffer = itr->buffer;
*itr = internalRecv(sendRank, recvRank, isLocal);
itr->requestId = tmpRequestId;

// If we have received a non-inlined message, copy the data into
// the provided buffer and free the heap buffer in the queue
size_t dataSize = itr->count * itr->typeSize;
if (dataSize > MPI_MAX_INLINE_SEND) {
assert(itr->buffer != nullptr);
assert(providedBuffer != nullptr);

std::memcpy(providedBuffer, itr->buffer, dataSize);

faabric::util::free(itr->buffer);
} else if (dataSize > 0) {
std::memcpy(providedBuffer, itr->inlineMsg, dataSize);
}
itr->buffer = providedBuffer;
assert(itr->messageType != MpiMessageType::UNACKED_MPI_MESSAGE);

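One boundary subtlety worth making explicit (a standalone sketch of the convention, not code from this PR): a payload of exactly MPI_MAX_INLINE_SEND bytes still fits in inlineMsg, so the send, size, serialization, and receive paths must all agree that sizes up to and including 256 bytes are inline, and only strictly larger ones go through malloc:

#include <cassert>
#include <cstddef>

constexpr std::size_t maxInline = 256; // MPI_MAX_INLINE_SEND

// The single predicate every path must share; a mismatch at exactly 256
// bytes would serialize a heap pointer but never copy the payload
bool isInlined(std::size_t payloadSz)
{
    return payloadSz <= maxInline;
}

int main()
{
    assert(isInlined(0));              // empty messages are trivially inline
    assert(isInlined(maxInline));      // exactly 256 bytes still fits the union
    assert(!isInlined(maxInline + 1)); // 257 bytes and up are malloc-ed

    return 0;
}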
3 changes: 0 additions & 3 deletions tests/dist/mpi/examples/mpi_isendrecv.cpp
@@ -41,9 +41,6 @@ int iSendRecv()
}
printf("Rank %i - async working properly\n", rank);

delete sendRequest;
delete recvRequest;

MPI_Finalize();

return 0;
3 changes: 0 additions & 3 deletions tests/dist/mpi/examples/mpi_send_sync_async.cpp
@@ -22,7 +22,6 @@ int sendSyncAsync()
MPI_Send(&r, 1, MPI_INT, r, 0, MPI_COMM_WORLD);
MPI_Wait(&sendRequest, MPI_STATUS_IGNORE);
}
delete sendRequest;
} else {
// Asynchronously receive twice from rank 0
int recvValue1 = -1;
@@ -47,8 +46,6 @@
rank);
return 1;
}
delete recvRequest1;
delete recvRequest2;
}
printf("Rank %i - send sync and async working properly\n", rank);

53 changes: 41 additions & 12 deletions tests/test/mpi/test_mpi_message.cpp
@@ -18,9 +18,9 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
return false;
}

// First, compare the message body (excluding the union at the end)
size_t unionSize = sizeof(uint8_t) * MPI_MAX_INLINE_SEND;
if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - unionSize) != 0) {
return false;
}

@@ -35,7 +35,11 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
// Assert, as this should pass given the previous comparisons
assert(payloadSizeA == payloadSizeB);

if (payloadSizeA > MPI_MAX_INLINE_SEND) {
return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
}

return std::memcmp(msgA.inlineMsg, msgB.inlineMsg, payloadSizeA) == 0;
}

TEST_CASE("Test getting a message size", "[mpi]")
@@ -59,11 +63,23 @@ TEST_CASE("Test getting a message size", "[mpi]")
expectedPayloadSize = 0;
}

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
msg.count = nums.size();
msg.typeSize = sizeof(int);
std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));

expectedPayloadSize = sizeof(int) * nums.size();
expectedMsgSize = sizeof(MpiMessage);
}

SECTION("Non-empty (large) message")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
std::vector<int32_t> nums(maxNumInts + 3, 3);
msg.count = nums.size();
msg.typeSize = sizeof(int);
msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));

@@ -74,7 +90,7 @@ TEST_CASE("Test getting a message size", "[mpi]")
REQUIRE(expectedMsgSize == msgSize(msg));
REQUIRE(expectedPayloadSize == payloadSize(msg));

if (expectedPayloadSize > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
faabric::util::free(msg.buffer);
}
}
@@ -95,11 +111,22 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
msg.buffer = nullptr;
}

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
msg.count = nums.size();
msg.typeSize = sizeof(int);
std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
}

SECTION("Non-empty (large) message")
{
// Make sure we send more ints than the maximum inline
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
std::vector<int32_t> nums(maxNumInts + 3, 3);
msg.count = nums.size();
msg.typeSize = sizeof(int);
REQUIRE(payloadSize(msg) > MPI_MAX_INLINE_SEND);
msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
}
@@ -113,11 +140,13 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")

REQUIRE(areMpiMsgEqual(msg, parsedMsg));

if (msg.count * msg.typeSize > MPI_MAX_INLINE_SEND) {
if (msg.buffer != nullptr) {
faabric::util::free(msg.buffer);
}
if (parsedMsg.buffer != nullptr) {
faabric::util::free(parsedMsg.buffer);
}
}
}
}
31 changes: 28 additions & 3 deletions tests/test/mpi/test_mpi_world.cpp
@@ -239,11 +239,17 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send and recv on same host", "[mpi]")
int rankA2 = 1;
std::vector<int> messageData;

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
messageData = { 0, 1, 2 };
}

SECTION("Non-empty (large) message")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
messageData = std::vector<int>(maxNumInts + 3, 3);
}

SECTION("Empty message")
{
messageData = {};
@@ -273,8 +279,27 @@ TEST_CASE_METHOD(MpiTestFixture, "Test sendrecv", "[mpi]")
int rankA = 1;
int rankB = 2;
MPI_Status status{};
std::vector<int> messageDataAB;
std::vector<int> messageDataBA;

SECTION("Empty messages")
{
messageDataAB = {};
messageDataBA = {};
}

SECTION("Small messages")
{
messageDataAB = { 0, 1, 2 };
messageDataBA = { 3, 2, 1, 0 };
}

SECTION("Large messages")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
messageDataAB = std::vector<int>(maxNumInts + 3, 3);
messageDataBA = std::vector<int>(maxNumInts + 4, 4);
}

// Results
std::vector<int> recvBufferA(messageDataBA.size(), 0);