Skip to content

Commit

Permalink
[ROCm] Fix pjrt topology code to support ROCm
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Jul 1, 2024
1 parent 153798d commit a6d58eb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
9 changes: 8 additions & 1 deletion xla/pjrt/c/pjrt_c_api_gpu_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,16 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create(
/*num_slices=*/1,
/*num_hosts_per_slice=*/1,
/*num_devices_per_host=*/device_ids.size());

// Determine the platform ID and name based on the platform.
xla::PjRtPlatformId platform_id =
(std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmId() : xla::CudaId();
std::string platform_name =
(std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmName() : xla::CudaName();

auto pjrt_topology =
std::make_unique<xla::StreamExecutorGpuTopologyDescription>(
xla::CudaId(), xla::CudaName(), std::move(gpu_topology),
platform_id, platform_name, std::move(gpu_topology),
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>{
{"target_config",
gpu_target_config.ToProto().SerializeAsString()}});
Expand Down
31 changes: 31 additions & 0 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_internal.h"

namespace pjrt {
namespace {
Expand Down Expand Up @@ -409,6 +410,36 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) {
api->PJRT_Error_Destroy(&error_destroy_args);
}

// Verifies that PJRT_TopologyDescription_Create reports the platform id/name
// matching the GPU backend the plugin was built for (ROCm vs. CUDA), and that
// the topology can be destroyed cleanly through the C API.
TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) {
  auto pjrt_api = gpu_plugin::GetGpuPjrtApi();

  PJRT_TopologyDescription_Create_Args args;
  args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.topology = nullptr;

  PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args);
  // ASSERT (not EXPECT): everything below dereferences args.topology, so
  // continuing after a creation failure would crash the test binary.
  // The streamed message is only evaluated when the assertion fails,
  // i.e. when `error` is non-null.
  ASSERT_EQ(error, nullptr) << error->status.message();

  // args.topology is already a PJRT_TopologyDescription*; no reinterpret_cast
  // (and no const_cast round-trip later) is needed.
  ASSERT_NE(args.topology, nullptr);

#ifdef TENSORFLOW_USE_ROCM
  EXPECT_EQ(args.topology->topology->platform_id(), xla::RocmId());
  EXPECT_EQ(args.topology->topology->platform_name(), xla::RocmName());
#else
  EXPECT_EQ(args.topology->topology->platform_id(), xla::CudaId());
  EXPECT_EQ(args.topology->topology->platform_name(), xla::CudaName());
#endif

  PJRT_TopologyDescription_Destroy_Args destroy_args;
  destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE;
  destroy_args.extension_start = nullptr;
  destroy_args.topology = args.topology;
  PJRT_Error* destroy_error =
      pjrt_api->PJRT_TopologyDescription_Destroy(&destroy_args);
  EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message();
}

// Intentionally empty custom-call target; presumably registered by the
// CustomCall tests that follow to exercise registration/lookup rather than
// call behavior — TODO(review): confirm against the registration sites.
void TestCustomCallV2() {}

TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) {
Expand Down

0 comments on commit a6d58eb

Please sign in to comment.