[XLA] Switch replay_computation to use LocalClient.
This lets replay_computation build an executable once and run it multiple times. This is particularly important because in XLA:GPU, the first run of an executable does some autotuning and therefore is unrepresentative.

This change removes --xla_hlo_profile_last_run, because I don't see how to support it in LocalClient -- LocalClient wants the do-profile bit to be set when we *compile*. (There may not be an easy fix for this; it worked with regular Client because we were recompiling every time we ran.)

PiperOrigin-RevId: 198643577
parent 0895714301
commit 49535c9da6
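A minimal sketch of the compile-once, run-many pattern this change adopts, assembled only from calls that appear in the hunks below (illustrative, not the tool's exact code; namespace qualifiers omitted as in the tool):

// Compile exactly once, then execute repeatedly. On XLA:GPU the first Run
// pays the autotuning cost, so later iterations are the representative ones.
StatusOr<ScopedShapedBuffer> CompileOnceRunMany(
    LocalClient* client, const XlaComputation& computation,
    const std::vector<const ShapedBuffer*>& argument_ptrs, int num_runs) {
  CHECK_GT(num_runs, 0);
  std::vector<const Shape*> argument_layouts;
  for (const auto& param : computation.proto().program_shape().parameters()) {
    argument_layouts.push_back(&param);
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<LocalExecutable> executable,
      client->Compile(computation, argument_layouts, ExecutableBuildOptions()));

  StreamExecutorMemoryAllocator allocator(
      client->platform(),
      {client->platform()->ExecutorForDevice(0).ValueOrDie()});
  ExecutableRunOptions run_options;
  run_options.set_allocator(&allocator);

  tensorflow::gtl::optional<ScopedShapedBuffer> result;
  for (int i = 0; i < num_runs; ++i) {
    TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
  }
  return std::move(*result);
}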
tensorflow/compiler/xla/client/local_client.cc:
@@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
       shaped_buffer);
 }
 
+StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
+}
+
 Status LocalClient::TransferToInfeedLocal(const Literal& literal,
                                           int device_ordinal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
tensorflow/compiler/xla/client/local_client.h:
@@ -136,6 +136,11 @@ class LocalClient : public Client {
   StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
       const ShapedBuffer& shaped_buffer);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
   // Transfer the given literal to the infeed queue of the given device.
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
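The comment above pins down the ownership contract: the returned pointer borrows from the handle. A short usage sketch (hedged; `global_data_arguments` is the vector the tool below builds with MakeFakeArgumentsOrDie):

// Borrowing, not owning: each ShapedBuffer* stays valid only while its
// GlobalData handle is alive, so keep the handles for as long as the
// pointers are used.
std::vector<const ShapedBuffer*> argument_ptrs;
for (const std::unique_ptr<GlobalData>& data : global_data_arguments) {
  argument_ptrs.push_back(
      client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/0)
          .ValueOrDie());
}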
tensorflow/compiler/xla/service/local_service.cc:
@@ -260,4 +260,15 @@ StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
                                       /*computation_count=*/1);
 }
 
+StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
+  if (replica_number >= buffers.size()) {
+    return InvalidArgument(
+        "replica_number %d out of range; must be less than num_replicas = %zu.",
+        replica_number, buffers.size());
+  }
+  return buffers[replica_number];
+}
+
 }  // namespace xla
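A sketch of the failure mode the bounds check above guards against: asking for a replica the allocation does not have yields InvalidArgument instead of an out-of-bounds read (`client` and `data` assumed as in the tool below):

StatusOr<const ShapedBuffer*> buffer_or =
    client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/1);
if (!buffer_or.ok()) {
  // e.g. "replica_number 1 out of range; must be less than num_replicas = 1."
  LOG(ERROR) << buffer_or.status();
}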
tensorflow/compiler/xla/service/local_service.h:
@@ -70,6 +70,11 @@ class LocalService : public Service {
   // the "easy" case where a single replica is a single device.
   StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
  private:
   explicit LocalService(const ServiceOptions& options,
                         std::unique_ptr<Backend> backend);
tensorflow/compiler/xla/tools/replay_computation.cc:
@@ -68,7 +68,6 @@ struct Options {
   bool use_fake_data = false;
   bool print_result = true;
   int num_runs = 1;
-  bool xla_hlo_profile_last_run = false;
 };
 
 // Invokes the given computation passing arbitrary data for every (unbound)
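For readers wondering where the profile bit would now have to live: with LocalClient, DebugOptions take effect at compile time, so a port of the removed flag would look roughly like the sketch below. The mutable_debug_options() hook on ExecutableBuildOptions is an assumption, not an API that exists at this revision -- which is exactly why the flag was dropped rather than ported:

// Hypothetical only: assumes ExecutableBuildOptions exposed the DebugOptions
// consumed during compilation. No such hook is relied on by this change.
ExecutableBuildOptions build_options;
build_options.mutable_debug_options()->set_xla_hlo_profile(true);  // assumed
std::unique_ptr<LocalExecutable> executable =
    client->Compile(computation, argument_layouts, build_options).ValueOrDie();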
@@ -80,21 +79,35 @@ struct Options {
 //
 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
 // no infeed is performed.
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
-                                                     Client* client,
-                                                     const Options& opts) {
+StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
+                                    LocalClient* client, const Options& opts) {
   XlaComputation computation(module.hlo().hlo_module());
 
-  std::vector<std::unique_ptr<GlobalData>> arguments;
+  // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
+  // arguments. This is a bit involved, because we may have to convert from
+  // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
+  // objects.
+  std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
+  std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
+  std::vector<const ShapedBuffer*> argument_ptrs;
   if (opts.use_fake_data) {
-    arguments = MakeFakeArgumentsOrDie(computation, client);
+    global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
+    for (const auto& data : global_data_arguments) {
+      argument_ptrs.push_back(
+          client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
+              .ValueOrDie());
+    }
   } else {  // use recorded data if available
     for (const auto& proto : module.arguments()) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                           Literal::CreateFromProto(proto));
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
-                          client->TransferToServer(*literal));
-      arguments.push_back(std::move(data));
+      TF_ASSIGN_OR_RETURN(
+          ScopedShapedBuffer data,
+          client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+      scoped_shaped_buffer_arguments.push_back(std::move(data));
     }
+    for (const auto& argument : scoped_shaped_buffer_arguments) {
+      argument_ptrs.push_back(&argument);
+    }
   }
 
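One subtlety in the two-loop structure above: `argument_ptrs` holds raw pointers into `scoped_shaped_buffer_arguments`, and the code takes addresses only after that vector has stopped growing, because a reallocating push_back would invalidate earlier pointers. A self-contained illustration (plain C++, not XLA code):

#include <vector>

// Why addresses are taken only after the owning vector is fully built:
// std::vector growth may reallocate, dangling pointers to earlier elements.
void PointerStabilityDemo() {
  std::vector<int> owners;
  std::vector<const int*> unsafe_ptrs;
  for (int i = 0; i < 100; ++i) {
    owners.push_back(i);
    unsafe_ptrs.push_back(&owners.back());  // may dangle after reallocation
  }
  // Safe (the pattern used above): finish growing, then take addresses.
  std::vector<const int*> safe_ptrs;
  for (const int& v : owners) safe_ptrs.push_back(&v);
}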
@@ -149,43 +162,41 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
     });
   }
 
-  std::vector<GlobalData*> execute_arguments;
-  execute_arguments.reserve(arguments.size());
-  for (auto& argument : arguments) {
-    execute_arguments.push_back(argument.get());
+  std::vector<const Shape*> argument_layouts;
+  for (const auto& param : computation.proto().program_shape().parameters()) {
+    argument_layouts.push_back(&param);
   }
+  std::unique_ptr<LocalExecutable> executable =
+      client->Compile(computation, argument_layouts, ExecutableBuildOptions())
+          .ValueOrDie();
 
   // Run the computation num_runs times, and return the result from the last
   // execution.
-  std::unique_ptr<Literal> result;
+  StreamExecutorMemoryAllocator allocator(
+      client->platform(),
+      {client->platform()->ExecutorForDevice(0).ValueOrDie()});
+  tensorflow::gtl::optional<ScopedShapedBuffer> result;
   for (int i = 0; i < opts.num_runs; ++i) {
     ExecutionProfile profile;
-    ExecutionOptions execution_options = CreateDefaultExecutionOptions();
-    if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
-      execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
-    }
+    ExecutableRunOptions run_options;
+    run_options.set_execution_profile(&profile);
+    run_options.set_allocator(&allocator);
 
-    if (opts.print_result) {
-      TF_ASSIGN_OR_RETURN(
-          result, client->ExecuteAndTransfer(computation, execute_arguments,
-                                             &execution_options, &profile));
-    } else {
-      // If we're not printing the result, execute the computation but don't
-      // bother retrieving the result. This can be a significant speedup.
-      TF_RETURN_IF_ERROR(client
-                             ->Execute(computation, execute_arguments,
-                                       &execution_options, &profile)
-                             .status());
-    }
+    TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
     LOG(INFO) << "Execution took "
               << static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
   }
 
-  return std::move(result);
+  // Check that --num_runs > 0, otherwise *result below will fail with an
+  // unhelpful error (because the loop didn't run any iterations).
+  CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+                      client->ShapedBufferToLiteral(*result));
+  return std::move(*result_literal);
 }
 
 int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
-  Client* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = ClientLibrary::LocalClientOrDie();
   tensorflow::Env* env = tensorflow::Env::Default();
   int exit_status = EXIT_SUCCESS;
   for (char* arg : args) {
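Since the commit message notes that run 0 autotunes on XLA:GPU, a natural consumer-side refinement is to discard the first measurement. A hedged sketch of such an extension (not part of this change; reuses `executable`, `argument_ptrs`, `run_options`, and `result` from the loop above):

// Hedged extension, not in this change: average compute time over runs
// 1..num_runs-1, excluding run 0, whose time includes autotuning.
double total_s = 0;
for (int i = 0; i < opts.num_runs; ++i) {
  ExecutionProfile profile;
  run_options.set_execution_profile(&profile);
  TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
  if (i > 0) {
    total_s += static_cast<double>(profile.compute_time_ns()) / 1e9;
  }
}
if (opts.num_runs > 1) {
  LOG(INFO) << "Mean compute time excluding run 0: "
            << total_s / (opts.num_runs - 1) << "s";
}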
@@ -202,8 +213,8 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
       CHECK(opts.use_fake_data)
           << "HloProto input must be handled with --use_fake_data";
     }
-    StatusOr<std::unique_ptr<Literal>> result_status =
-        ReplayComputation(snapshot, client, opts);
+    StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
 
     if (!result_status.ok()) {
       fprintf(stderr, "%s: error: %s\n", arg,
               result_status.status().ToString().c_str());
@@ -211,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
       continue;
     }
 
-    std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
-    if (result != nullptr) {
+    if (opts.print_result) {
+      Literal result = std::move(result_status).ValueOrDie();
       fprintf(stdout, "%s: %s :: %s:%s\n", arg,
               snapshot.hlo().hlo_module().name().c_str(),
-              ShapeUtil::HumanString(result->shape()).c_str(),
-              result->ToString().c_str());
+              ShapeUtil::HumanString(result.shape()).c_str(),
+              result.ToString().c_str());
       if (snapshot.has_result()) {
         std::unique_ptr<Literal> literal =
             Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
@@ -249,9 +260,6 @@ int main(int argc, char** argv) {
       tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
                        "Whether a fake infeed shape should be generated "
                        "derived from the computation"),
-      tensorflow::Flag(
-          "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
-          "Pass --xla_hlo_profile the last time we run the computation."),
   };
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);