[XLA] Switch replay_computation to use LocalClient.

This lets replay_computation build an executable once and run it
multiple times.  This is particularly important because in XLA:GPU, the
first run of an executable does some autotuning and therefore is
unrepresentative.

This change removes --xla_hlo_profile_last_run, because I don't see how
to support it in LocalClient -- LocalClient wants the do-profile bit to
be set when we *compile*.  (There may not be an easy fix for this; it
worked with regular Client because we were recompiling every time we
ran.)

PiperOrigin-RevId: 198643577
This commit is contained in:
Justin Lebar 2018-05-30 17:00:50 -07:00 committed by TensorFlower Gardener
parent 0895714301
commit 49535c9da6
5 changed files with 75 additions and 41 deletions

View File

@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
shaped_buffer);
}
// Returns the ShapedBuffer that `data` refers to for replica `replica_number`.
// This is a thin wrapper: the lookup (and any error it produces) is delegated
// to the underlying LocalService.
StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
const GlobalDataHandle& data, int replica_number) {
return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}
Status LocalClient::TransferToInfeedLocal(const Literal& literal,
int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,

View File

@ -136,6 +136,11 @@ class LocalClient : public Client {
StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
//
// `replica_number` selects which of the handle's per-replica buffers is
// returned. The call is forwarded to LocalService::GlobalDataToShapedBuffer,
// so an error status from the service is returned to the caller unchanged.
StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
const GlobalDataHandle& data, int replica_number);
// Transfer the given literal to the infeed queue of the given device.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with

View File

@ -260,4 +260,15 @@ StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
/*computation_count=*/1);
}
// Resolves `data` through the allocation tracker and returns the buffer stored
// at index `replica_number`.
//
// Returns an InvalidArgument error if `replica_number` is negative or is not
// less than the number of buffers the handle resolves to. An error from
// resolving the handle itself is propagated to the caller.
StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
  // Reject negative indices explicitly rather than relying on the implicit
  // int -> size_t conversion to make them compare as "too large".
  if (replica_number < 0 ||
      static_cast<size_t>(replica_number) >= buffers.size()) {
    return InvalidArgument(
        "replica_number %d out of range; must be less than num_replicas = %zu.",
        replica_number, buffers.size());
  }
  return buffers[replica_number];
}
} // namespace xla

View File

@ -70,6 +70,11 @@ class LocalService : public Service {
// the "easy" case where a single replica is a single device.
StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
//
// `replica_number` indexes into the buffers the handle resolves to. Returns
// an InvalidArgument error if it is not less than the number of those
// buffers, and propagates any error from resolving the handle.
StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
const GlobalDataHandle& data, int replica_number);
private:
explicit LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> backend);

View File

@ -68,7 +68,6 @@ struct Options {
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
bool xla_hlo_profile_last_run = false;
};
// Invokes the given computation passing arbitrary data for every (unbound)
@ -80,21 +79,35 @@ struct Options {
//
// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
// no infeed is performed.
StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
Client* client,
const Options& opts) {
StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
LocalClient* client, const Options& opts) {
XlaComputation computation(module.hlo().hlo_module());
std::vector<std::unique_ptr<GlobalData>> arguments;
// Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
// arguments. This is a bit involved, because we may have to convert from
// GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
// objects.
std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
std::vector<const ShapedBuffer*> argument_ptrs;
if (opts.use_fake_data) {
arguments = MakeFakeArgumentsOrDie(computation, client);
global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
// Resolve each fake-argument handle to the ShapedBuffer it refers to. The
// second argument of GlobalDataToShapedBuffer is a replica number; the
// returned pointers are only documented as valid while the corresponding
// handle in `global_data_arguments` is valid.
for (const auto& data : global_data_arguments) {
argument_ptrs.push_back(
client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/0)
.ValueOrDie());
}
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
client->TransferToServer(*literal));
arguments.push_back(std::move(data));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer data,
client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
scoped_shaped_buffer_arguments.push_back(std::move(data));
}
for (const auto& argument : scoped_shaped_buffer_arguments) {
argument_ptrs.push_back(&argument);
}
}
@ -149,43 +162,41 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
});
}
std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) {
execute_arguments.push_back(argument.get());
std::vector<const Shape*> argument_layouts;
for (const auto& param : computation.proto().program_shape().parameters()) {
argument_layouts.push_back(&param);
}
std::unique_ptr<LocalExecutable> executable =
client->Compile(computation, argument_layouts, ExecutableBuildOptions())
.ValueOrDie();
// Run the computation num_runs times, and return the result from the last
// execution.
std::unique_ptr<Literal> result;
StreamExecutorMemoryAllocator allocator(
client->platform(),
{client->platform()->ExecutorForDevice(0).ValueOrDie()});
tensorflow::gtl::optional<ScopedShapedBuffer> result;
for (int i = 0; i < opts.num_runs; ++i) {
ExecutionProfile profile;
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
}
ExecutableRunOptions run_options;
run_options.set_execution_profile(&profile);
run_options.set_allocator(&allocator);
if (opts.print_result) {
TF_ASSIGN_OR_RETURN(
result, client->ExecuteAndTransfer(computation, execute_arguments,
&execution_options, &profile));
} else {
// If we're not printing the result, execute the computation but don't
// bother retrieving the result. This can be a significant speedup.
TF_RETURN_IF_ERROR(client
->Execute(computation, execute_arguments,
&execution_options, &profile)
.status());
}
TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
LOG(INFO) << "Execution took "
<< static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
}
return std::move(result);
// Check that --num_runs > 0, otherwise *result below will fail with an
// unhelpful error (because the loop didn't run any iterations).
CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
client->ShapedBufferToLiteral(*result));
return std::move(*result_literal);
}
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
Client* client = ClientLibrary::LocalClientOrDie();
LocalClient* client = ClientLibrary::LocalClientOrDie();
tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
@ -202,8 +213,8 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
CHECK(opts.use_fake_data)
<< "HloProto input must be handled with --use_fake_data";
}
StatusOr<std::unique_ptr<Literal>> result_status =
ReplayComputation(snapshot, client, opts);
StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
@ -211,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
continue;
}
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
if (result != nullptr) {
if (opts.print_result) {
Literal result = std::move(result_status).ValueOrDie();
fprintf(stdout, "%s: %s :: %s:%s\n", arg,
snapshot.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(result->shape()).c_str(),
result->ToString().c_str());
ShapeUtil::HumanString(result.shape()).c_str(),
result.ToString().c_str());
if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
@ -249,9 +260,6 @@ int main(int argc, char** argv) {
tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
"Whether a fake infeed shape should be generated "
"derived from the computation"),
tensorflow::Flag(
"xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
"Pass --xla_hlo_profile the last time we run the computation."),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);