Skip to content

Commit

Permalink
[XLA:CPU] Update RunHloBenchmark to enable running with HLO with infe…
Browse files Browse the repository at this point in the history
…rred arguments

PiperOrigin-RevId: 698740778
  • Loading branch information
Google-ML-Automation committed Nov 22, 2024
1 parent e92f516 commit 60ae8eb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
1 change: 1 addition & 0 deletions xla/service/cpu/benchmarks/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cc_library(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/cpu:cpu_client",
"//xla/service:hlo_module_config",
"//xla/tests:test_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down
34 changes: 29 additions & 5 deletions xla/service/cpu/benchmarks/hlo_benchmark_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h"

#include <cstddef>
#include <memory>
#include <string_view>
#include <vector>
Expand All @@ -30,6 +31,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tests/test_utils.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test_benchmark.h"
Expand Down Expand Up @@ -63,12 +65,34 @@ absl::Status RunHloBenchmark(benchmark::State& state,

// Convert literals to PjRtBuffers.
std::vector<std::unique_ptr<PjRtBuffer>> args_buffers;
args_buffers.reserve(args.size());

for (const Literal* arg : args) {
TF_ASSIGN_OR_RETURN(args_buffers.emplace_back(),
client->BufferFromHostLiteral(*arg, device));
TF_RETURN_IF_ERROR(args_buffers.back()->GetReadyFuture().Await());
size_t expected_arg_count =
module->entry_computation()->parameter_instructions().size();

// If the user has not passed any arguments we need to generate
// fake arguments based on the number of inputs to the hlo module.
if (args.empty()) {
TF_ASSIGN_OR_RETURN(std::vector<Literal> fake_args,
MakeFakeArguments(module.get()));
args_buffers.reserve(fake_args.size());
for (const Literal& arg : fake_args) {
TF_ASSIGN_OR_RETURN(args_buffers.emplace_back(),
client->BufferFromHostLiteral(arg, device));
TF_RETURN_IF_ERROR(args_buffers.back()->GetReadyFuture().Await());
}
} else {
if (expected_arg_count != args.size()) {
return absl::InvalidArgumentError(
"Number of arguments does not match the number of parameters in "
"the HLO module.");
}

args_buffers.reserve(args.size());
for (const Literal* arg : args) {
TF_ASSIGN_OR_RETURN(args_buffers.emplace_back(),
client->BufferFromHostLiteral(*arg, device));
TF_RETURN_IF_ERROR(args_buffers.back()->GetReadyFuture().Await());
}
}

// Execute in synchronous mode to avoid thread hops.
Expand Down

0 comments on commit 60ae8eb

Please sign in to comment.