From be43153b21fba4280f0f2a016242614a21d115f3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 19:02:03 -0800 Subject: [PATCH] [XLA:HLO] Change dumped_computation_to_text to dump the raw HLO by default. Added a --compile flag (default false) to enable the old behavior of dumping the HLO after compilation. It's not clear that the --compile mode is actually useful; to actually be useful we'd probably want to add a --platform flag to easily select the backend as well. Change: 154379349 --- tensorflow/compiler/xla/tools/BUILD | 1 + .../xla/tools/dumped_computation_to_text.cc | 60 +++++++++++++------ 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 1d9baf5de10..535e5b605b4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -153,6 +153,7 @@ cc_binary( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 8b96e134897..1f0ca31d6d6 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice<char*> args) { +void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) { } Computation computation = computation_status.ConsumeValueOrDie(); - std::unique_ptr<ProgramShape> program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); + if (compile) { + std::unique_ptr<ProgramShape> program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); - std::vector<const Shape*> layouts; - for (int i = 0; i < program_shape->parameters_size(); ++i) { - layouts.push_back(&program_shape->parameters(i)); + std::vector<const Shape*> layouts; + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr<std::unique_ptr<Executable>> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const HloModule& module = executable.ValueOrDie()->module(); + + fprintf(stdout, "HLO compiled for %s backend:\n%s\n", + local_service->backend().platform()->Name().c_str(), + module.ToString().c_str()); + } else { + const ComputationTracker& tracker = local_service->computation_tracker(); + UserComputation* user_computation = + tracker.Resolve(computation.handle()).ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + std::unique_ptr<HloModule> module = + tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie(); + + fprintf(stdout, "%s\n", module->ToString().c_str()); } - StatusOr<std::unique_ptr<Executable>> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); - - const HloModule& module = executable.ValueOrDie()->module(); - - fprintf(stdout, "HLO for %s 
backend:\n%s\n", - local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); } } @@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); + bool compile = false; + std::vector<tensorflow::Flag> flag_list = { + {"compile", &compile, + "If true, compile the computation using the default client before " + "dumping the HLO. Otherwise dump the raw (uncompiled) HLO."}, + }; + const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; tensorflow::gtl::ArraySlice<char*> args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); + xla::tools::RealMain(args, compile); return 0; }