Skip to content

Commit

Permalink
rocm updated graph api and fixed hlo_op_profiler_test
Browse files Browse the repository at this point in the history
  • Loading branch information
i-chaochen committed Aug 10, 2023
1 parent 8862bce commit 9ea7cbd
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 10 deletions.
6 changes: 3 additions & 3 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3154,15 +3154,15 @@ xla_cc_test(

xla_cc_test(
name = "hlo_op_profiler_test",
srcs = if_cuda_is_configured(["hlo_op_profiler_test.cc"]),
srcs = ["hlo_op_profiler_test.cc"],
tags = tf_cuda_tests_tags(),
deps = if_cuda_is_configured([
deps = [
":hlo_op_profiler_lib",
"//xla/hlo/ir:hlo",
"//xla/service:gpu_plugin",
"//xla/tests:hlo_test_base",
"@tsl//tsl/platform:test_main",
]),
],
)

cc_library(
Expand Down
176 changes: 169 additions & 7 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
bool FLAGS_gpuexec_rocm_device_0_only = false;

#define RETURN_IF_ROCM_ERROR(expr, ...) \
do { \
hipError_t _res = (expr); \
if (TF_PREDICT_FALSE(_res != hipSuccess)) { \
return tsl::errors::Internal(__VA_ARGS__, ": ", \
::stream_executor::gpu::ToString(_res)); \
} \
#define RETURN_IF_ROCM_ERROR(expr, ...) \
do { \
hipError_t _res = (expr); \
if (TF_PREDICT_FALSE(_res != hipSuccess)) { \
if (_res == hipErrorOutOfMemory) \
return tsl::errors::ResourceExhausted( \
__VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res)); \
else \
return tsl::errors::Internal(__VA_ARGS__, ": ", \
::stream_executor::gpu::ToString(_res));\
} \
} while (0)

// Debugging: on each push and pop of a rocm context, verify the current device
Expand Down Expand Up @@ -396,6 +400,164 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
return tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::CreateGraph(hipGraph_t* graph) {
VLOG(2) << "Create new HIP graph";
RETURN_IF_ROCM_ERROR(hipGraphCreate(graph, /*flags=*/0),
"Failed to create HIP graph");
VLOG(2) << "Created HIP graph " << graph;
return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::DestroyGraph(hipGraph_t graph) {
VLOG(2) << "Destroy HIP graph " << graph;
RETURN_IF_ROCM_ERROR(hipGraphDestroy(graph),
"Failed to destroy HIP graph");
return ::tsl::OkStatus();
}

static std::string_view StreamCaptureModeToString(
GpuDriver::StreamCaptureMode mode) {
switch (mode) {
case GpuDriver::StreamCaptureMode::kGlobal:
return "global";
case GpuDriver::StreamCaptureMode::kThreadLocal:
return "threadlocal";
case GpuDriver::StreamCaptureMode::kRelaxed:
return "relaxed";
}
}

/* static */ tsl::Status GpuDriver::StreamBeginCapture(GpuStreamHandle stream,
StreamCaptureMode mode) {
hipStreamCaptureMode hip_mode;
switch (mode) {
case StreamCaptureMode::kGlobal:
hip_mode = hipStreamCaptureModeGlobal;
break;
case StreamCaptureMode::kThreadLocal:
hip_mode = hipStreamCaptureModeThreadLocal;
break;
case StreamCaptureMode::kRelaxed:
hip_mode = hipStreamCaptureModeRelaxed;
break;
}

VLOG(2) << "Beging stream " << stream << " capture in "
<< StreamCaptureModeToString(mode) << " mode";
RETURN_IF_ROCM_ERROR(hipStreamBeginCapture(stream, hip_mode),
"Failed to begin stream capture");
return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::StreamEndCapture(GpuStreamHandle stream,
hipGraph_t* graph) {
VLOG(2) << "End stream " << stream << " capture";

RETURN_IF_ROCM_ERROR(hipStreamEndCapture(stream, graph),
"Failed to end stream capture");

return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::GraphInstantiate(
hipGraphExec_t* exec, hipGraph_t graph, const GraphInstantiateFlags& flags) {
VLOG(2) << "Instante HIP executable graph from graph " << graph << " ("
<< "auto_free_on_launch=" << flags.auto_free_on_launch << ", "
<< "device_launch=" << flags.device_launch << ", "
<< "use_node_priority=" << flags.use_node_prirotiy << ", "
<< "upload=" << flags.upload << ")";
RETURN_IF_ROCM_ERROR(hipGraphInstantiate(exec, graph, nullptr, nullptr, 0),
"Failed to instantiate HIP graph");
return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::GraphLaunch(hipGraphExec_t exec,
GpuStreamHandle stream) {
VLOG(2) << "Launching HIP executable graph " << exec << " on a stream "
<< stream;
RETURN_IF_ROCM_ERROR(hipGraphLaunch(exec, stream),
"Failed to launch HIP graph");
return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::GraphExecUpdate(
hipGraphExec_t exec, hipGraph_t graph, GraphExecUpdateResultInfo* result) {
VLOG(2) << "Update HIP graph executable " << exec << " with graph " << graph;

hipGraphExecUpdateResult hip_result;
RETURN_IF_ROCM_ERROR(hipGraphExecUpdate(exec, graph, nullptr, &hip_result),
"Failed to update HIP graph");
auto hip_result_enum = hip_result;

switch (hip_result_enum) {
case hipGraphExecUpdateSuccess:
result->result = GraphExecUpdateResult::kSuccess;
break;
case hipGraphExecUpdateError:
result->result = GraphExecUpdateResult::kError;
break;
case hipGraphExecUpdateErrorTopologyChanged:
result->result = GraphExecUpdateResult::kTopologyChanged;
break;
case hipGraphExecUpdateErrorNodeTypeChanged:
result->result = GraphExecUpdateResult::kNodeTypeChanged;
break;
case hipGraphExecUpdateErrorFunctionChanged:
result->result = GraphExecUpdateResult::kFunctionChanged;
break;
case hipGraphExecUpdateErrorParametersChanged:
result->result = GraphExecUpdateResult::kParametersChanged;
break;
case hipGraphExecUpdateErrorNotSupported:
result->result = GraphExecUpdateResult::kNotSupported;
break;
case hipGraphExecUpdateErrorUnsupportedFunctionChange:
result->result = GraphExecUpdateResult::kUnsupportedFunctionChange;
break;
// TODO: HIP hasn't GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED yet
}

return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::DestroyGraphExec(hipGraphExec_t exec) {
VLOG(2) << "Destroying HIP executable graph" << exec;
RETURN_IF_ROCM_ERROR(hipGraphExecDestroy(exec),
"Failed to destroy HIP graph");
return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::GraphDebugDotPrint(hipGraph_t graph,
const char* path) {
VLOG(2) << "Print HIP graph " << graph << " debug dot file to " << path;

int flags = hipGraphDebugDotFlagsVerbose;
RETURN_IF_ROCM_ERROR(hipGraphDebugDotPrint(graph, path, flags),
"Failed to print gpu graph debug file");

if (VLOG_IS_ON(100)) {
std::string data;
if (tsl::ReadFileToString(tsl::Env::Default(), path, &data).ok()) {
VLOG(200) << "HIP graph " << graph << " debug file:\n" << data;
} else {
LOG(WARNING) << "failed to read gpu graph debug file " << path;
}
}

return ::tsl::OkStatus();
}

/* static */ tsl::StatusOr<bool> GpuDriver::StreamIsCapturing(GpuStreamHandle stream) {
VLOG(2) << "Checking if stream " << stream << " is capturing";

hipStreamCaptureStatus status;
RETURN_IF_ROCM_ERROR(hipStreamIsCapturing(stream, &status),
"Failed to check stream capturing status");

return status == hipStreamCaptureStatusActive;
}


/* static */ tsl::Status GpuDriver::LaunchKernel(
GpuContext* context, absl::string_view kernel_name, hipFunction_t function,
unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
Expand Down
2 changes: 2 additions & 0 deletions xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ namespace wrap {
__macro(hipGetDeviceProperties) \
__macro(hipGetErrorString) \
__macro(hipGraphDebugDotPrint) \
__macro(hipGraphDebugDotFlagsVerbose) \
__macro(hipGraphDestroy) \
__macro(hipGraphExecDestroy) \
__macro(hipGraphExecUpdate) \
__macro(hipGraphInstantiate) \
__macro(hipGraphLaunch) \
__macro(hipGraphCreate) \
__macro(hipHostFree) \
__macro(hipHostMalloc) \
__macro(hipHostRegister) \
Expand Down

0 comments on commit 9ea7cbd

Please sign in to comment.