[XLA] Switch replay_computation to use LocalClient.
This lets replay_computation build an executable once and run it multiple
times. This is particularly important because in XLA:GPU, the first run of an
executable does some autotuning and is therefore unrepresentative.

This change removes --xla_hlo_profile_last_run, because I don't see how to
support it with LocalClient: LocalClient wants the do-profile bit to be set
when we *compile*. (There may not be an easy fix for this; it worked with the
regular Client because we were recompiling every time we ran.)

PiperOrigin-RevId: 198643577
commit 49535c9da6
parent 0895714301
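In outline, replay_computation now follows a compile-once, run-many pattern.
The sketch below is simplified from the diff that follows (error handling
elided; `computation`, `argument_layouts`, `allocator`, and `argument_ptrs`
are built as in replay_computation.cc):

    // Compile a single executable up front...
    std::unique_ptr<LocalExecutable> executable =
        client->Compile(computation, argument_layouts, ExecutableBuildOptions())
            .ValueOrDie();
    for (int i = 0; i < opts.num_runs; ++i) {
      ExecutionProfile profile;
      ExecutableRunOptions run_options;
      run_options.set_execution_profile(&profile);
      run_options.set_allocator(&allocator);
      // ...and reuse it for every run. On XLA:GPU the first run may
      // autotune, so only the later runs are representative.
      ScopedShapedBuffer result =
          executable->Run(argument_ptrs, run_options).ValueOrDie();
    }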
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
       shaped_buffer);
 }
 
+StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
+}
+
 Status LocalClient::TransferToInfeedLocal(const Literal& literal,
                                           int device_ordinal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -136,6 +136,11 @@ class LocalClient : public Client {
   StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
       const ShapedBuffer& shaped_buffer);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
   // Transfer the given literal to the infeed queue of the given device.
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
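A hypothetical caller of the new accessor might look like this (assuming
`data` is a GlobalData previously created via TransferToServer; the
out-of-range error comes from the LocalService implementation below):

    StatusOr<const ShapedBuffer*> buffer_or =
        client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/0);
    if (!buffer_or.ok()) {
      // e.g. replica_number out of range for the allocation's num_replicas.
      LOG(ERROR) << buffer_or.status();
    }
    // Valid only as long as `data` (and thus the handle) stays alive.
    const ShapedBuffer* buffer = buffer_or.ValueOrDie();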
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -260,4 +260,15 @@ StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
                                       /*computation_count=*/1);
 }
 
+StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
+  if (replica_number >= buffers.size()) {
+    return InvalidArgument(
+        "replica_number %d out of range; must be less than num_replicas = %zu.",
+        replica_number, buffers.size());
+  }
+  return buffers[replica_number];
+}
+
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -70,6 +70,11 @@ class LocalService : public Service {
   // the "easy" case where a single replica is a single device.
   StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
  private:
   explicit LocalService(const ServiceOptions& options,
                         std::unique_ptr<Backend> backend);
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -68,7 +68,6 @@ struct Options {
   bool use_fake_data = false;
   bool print_result = true;
   int num_runs = 1;
-  bool xla_hlo_profile_last_run = false;
 };
 
 // Invokes the given computation passing arbitrary data for every (unbound)
@@ -80,21 +79,35 @@ struct Options {
 //
 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
 // no infeed is performed.
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
-                                                     Client* client,
-                                                     const Options& opts) {
+StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
+                                    LocalClient* client, const Options& opts) {
   XlaComputation computation(module.hlo().hlo_module());
 
-  std::vector<std::unique_ptr<GlobalData>> arguments;
+  // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
+  // arguments. This is a bit involved, because we may have to convert from
+  // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
+  // objects.
+  std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
+  std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
+  std::vector<const ShapedBuffer*> argument_ptrs;
   if (opts.use_fake_data) {
-    arguments = MakeFakeArgumentsOrDie(computation, client);
+    global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
+    for (const auto& data : global_data_arguments) {
+      argument_ptrs.push_back(
+          client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
+              .ValueOrDie());
+    }
   } else {  // use recorded data if available
     for (const auto& proto : module.arguments()) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                           Literal::CreateFromProto(proto));
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
-                          client->TransferToServer(*literal));
-      arguments.push_back(std::move(data));
+      TF_ASSIGN_OR_RETURN(
+          ScopedShapedBuffer data,
+          client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+      scoped_shaped_buffer_arguments.push_back(std::move(data));
+    }
+    for (const auto& argument : scoped_shaped_buffer_arguments) {
+      argument_ptrs.push_back(&argument);
     }
   }
 
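The lifetime comment above is the heart of the new argument handling:
`argument_ptrs` holds non-owning pointers, so the owning vectors must outlive
every Run() call, and addresses may only be taken once the owning vector has
stopped growing. In miniature (names hypothetical):

    std::vector<ScopedShapedBuffer> owned;           // owns device memory
    std::vector<const ShapedBuffer*> argument_ptrs;  // non-owning views
    owned.push_back(std::move(buffer));
    // Take addresses only after `owned` is fully built, so vector
    // reallocation cannot invalidate the pointers.
    for (const auto& b : owned) {
      argument_ptrs.push_back(&b);
    }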
@@ -149,43 +162,41 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
     });
   }
 
-  std::vector<GlobalData*> execute_arguments;
-  execute_arguments.reserve(arguments.size());
-  for (auto& argument : arguments) {
-    execute_arguments.push_back(argument.get());
+  std::vector<const Shape*> argument_layouts;
+  for (const auto& param : computation.proto().program_shape().parameters()) {
+    argument_layouts.push_back(&param);
   }
+  std::unique_ptr<LocalExecutable> executable =
+      client->Compile(computation, argument_layouts, ExecutableBuildOptions())
+          .ValueOrDie();
 
   // Run the computation num_runs times, and return the result from the last
   // execution.
-  std::unique_ptr<Literal> result;
+  StreamExecutorMemoryAllocator allocator(
+      client->platform(),
+      {client->platform()->ExecutorForDevice(0).ValueOrDie()});
+  tensorflow::gtl::optional<ScopedShapedBuffer> result;
   for (int i = 0; i < opts.num_runs; ++i) {
     ExecutionProfile profile;
-    ExecutionOptions execution_options = CreateDefaultExecutionOptions();
-    if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
-      execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
-    }
+    ExecutableRunOptions run_options;
+    run_options.set_execution_profile(&profile);
+    run_options.set_allocator(&allocator);
 
-    if (opts.print_result) {
-      TF_ASSIGN_OR_RETURN(
-          result, client->ExecuteAndTransfer(computation, execute_arguments,
-                                             &execution_options, &profile));
-    } else {
-      // If we're not printing the result, execute the computation but don't
-      // bother retrieving the result.  This can be a significant speedup.
-      TF_RETURN_IF_ERROR(client
-                             ->Execute(computation, execute_arguments,
-                                       &execution_options, &profile)
-                             .status());
-    }
+    TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
     LOG(INFO) << "Execution took "
               << static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
   }
 
-  return std::move(result);
+  // Check that --num_runs > 0, otherwise *result below will fail with an
+  // unhelpful error (because the loop didn't run any iterations).
+  CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+                      client->ShapedBufferToLiteral(*result));
+  return std::move(*result_literal);
 }
 
 int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
-  Client* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = ClientLibrary::LocalClientOrDie();
   tensorflow::Env* env = tensorflow::Env::Default();
   int exit_status = EXIT_SUCCESS;
   for (char* arg : args) {
@@ -202,8 +213,8 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
     CHECK(opts.use_fake_data)
         << "HloProto input must be handled with --use_fake_data";
   }
-  StatusOr<std::unique_ptr<Literal>> result_status =
-      ReplayComputation(snapshot, client, opts);
+  StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
   if (!result_status.ok()) {
     fprintf(stderr, "%s: error: %s\n", arg,
             result_status.status().ToString().c_str());
@@ -211,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
       continue;
     }
 
-    std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
-    if (result != nullptr) {
+    if (opts.print_result) {
+      Literal result = std::move(result_status).ValueOrDie();
       fprintf(stdout, "%s: %s :: %s:%s\n", arg,
               snapshot.hlo().hlo_module().name().c_str(),
-              ShapeUtil::HumanString(result->shape()).c_str(),
-              result->ToString().c_str());
+              ShapeUtil::HumanString(result.shape()).c_str(),
+              result.ToString().c_str());
     if (snapshot.has_result()) {
       std::unique_ptr<Literal> literal =
           Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
@@ -249,9 +260,6 @@ int main(int argc, char** argv) {
       tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
                        "Whether a fake infeed shape should be generated "
                        "derived from the computation"),
-      tensorflow::Flag(
-          "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
-          "Pass --xla_hlo_profile the last time we run the computation."),
   };
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);