[XLA:HLO] Change dumped_computation_to_text to dump the raw HLO by default.

Added a --compile flag (default false) to enable the old behavior of dumping the
HLO after compilation. It's not clear that the --compile mode is actually
useful; to actually be useful we'd probably want to add a --platform flag to
easily select the backend as well.
Change: 154379349
This commit is contained in:
A. Unique TensorFlower 2017-04-26 19:02:03 -08:00 committed by TensorFlower Gardener
parent 24316aa70c
commit be43153b21
2 changed files with 43 additions and 18 deletions

View File

@ -153,6 +153,7 @@ cc_binary(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:computation_tracker",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -34,7 +35,7 @@ limitations under the License.
namespace xla {
namespace tools {
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
}
Computation computation = computation_status.ConsumeValueOrDie();
std::unique_ptr<ProgramShape> program_shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
if (compile) {
std::unique_ptr<ProgramShape> program_shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
std::vector<const Shape*> layouts;
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
std::vector<const Shape*> layouts;
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(
computation.handle(), layouts, &program_shape->result(),
/*device_ordinal=*/0, /*has_hybrid_result=*/true);
const HloModule& module = executable.ValueOrDie()->module();
fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
local_service->backend().platform()->Name().c_str(),
module.ToString().c_str());
} else {
const ComputationTracker& tracker = local_service->computation_tracker();
UserComputation* user_computation =
tracker.Resolve(computation.handle()).ConsumeValueOrDie();
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
std::unique_ptr<HloModule> module =
tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie();
fprintf(stdout, "%s\n", module->ToString().c_str());
}
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(
computation.handle(), layouts, &program_shape->result(),
/*device_ordinal=*/0, /*has_hybrid_result=*/true);
const HloModule& module = executable.ValueOrDie()->module();
fprintf(stdout, "HLO for %s backend:\n%s\n",
local_service->backend().platform()->Name().c_str(),
module.ToString().c_str());
}
}
@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
} // namespace xla
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
bool compile = false;
std::vector<tensorflow::Flag> flag_list = {
{"compile", &compile,
"If true, compile the computation using the default client before "
"dumping the HLO. Otherwise dump the raw (uncompiled) HLO."},
};
const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
QCHECK(parsed_flags_ok) << "\n" << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
args.pop_front(); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
xla::tools::RealMain(args, compile);
return 0;
}