[XLA:HLO] Change dumped_computation_to_text to dump the raw HLO by default.

Added a --compile flag (default false) to enable the old behavior of dumping the HLO after compilation. It's not clear that the --compile mode is actually useful; to make it genuinely useful we'd probably also want a --platform flag for selecting the backend.
Change: 154379349
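For illustration, intended usage looks like this (the snapshot path below is a hypothetical example; the tool accepts one or more serialized computation snapshots as positional arguments):

  dumped_computation_to_text /tmp/computation.pb            # new default: raw, uncompiled HLO
  dumped_computation_to_text --compile /tmp/computation.pb  # old behavior: HLO after compilation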
parent 24316aa70c
commit be43153b21
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -153,6 +153,7 @@ cc_binary(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:computation_tracker",
         "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
 #include "tensorflow/compiler/xla/service/service.h"
 #include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -34,7 +35,7 @@ limitations under the License.
 namespace xla {
 namespace tools {
 
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
   LocalClient* client = ClientLibrary::LocalClientOrDie();
   LocalService* local_service =
       ClientLibrary::GetXlaService(client->platform());
@@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
     }
     Computation computation = computation_status.ConsumeValueOrDie();
 
-    std::unique_ptr<ProgramShape> program_shape =
-        client->GetComputationShape(computation).ConsumeValueOrDie();
-
-    std::vector<const Shape*> layouts;
-    for (int i = 0; i < program_shape->parameters_size(); ++i) {
-      layouts.push_back(&program_shape->parameters(i));
-    }
-    StatusOr<std::unique_ptr<Executable>> executable =
-        local_service->CompileExecutable(
-            computation.handle(), layouts, &program_shape->result(),
-            /*device_ordinal=*/0, /*has_hybrid_result=*/true);
-
-    const HloModule& module = executable.ValueOrDie()->module();
-
-    fprintf(stdout, "HLO for %s backend:\n%s\n",
-            local_service->backend().platform()->Name().c_str(),
-            module.ToString().c_str());
+    if (compile) {
+      std::unique_ptr<ProgramShape> program_shape =
+          client->GetComputationShape(computation).ConsumeValueOrDie();
+
+      std::vector<const Shape*> layouts;
+      for (int i = 0; i < program_shape->parameters_size(); ++i) {
+        layouts.push_back(&program_shape->parameters(i));
+      }
+      StatusOr<std::unique_ptr<Executable>> executable =
+          local_service->CompileExecutable(
+              computation.handle(), layouts, &program_shape->result(),
+              /*device_ordinal=*/0, /*has_hybrid_result=*/true);
+
+      const HloModule& module = executable.ValueOrDie()->module();
+
+      fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
+              local_service->backend().platform()->Name().c_str(),
+              module.ToString().c_str());
+    } else {
+      const ComputationTracker& tracker = local_service->computation_tracker();
+      UserComputation* user_computation =
+          tracker.Resolve(computation.handle()).ConsumeValueOrDie();
+      VersionedComputationHandle versioned_handle =
+          user_computation->GetVersionedHandle();
+      std::unique_ptr<HloModule> module =
+          tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie();
+
+      fprintf(stdout, "%s\n", module->ToString().c_str());
+    }
   }
 }
 
@@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
 }  // namespace xla
 
 int main(int argc, char** argv) {
-  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  bool compile = false;
+  std::vector<tensorflow::Flag> flag_list = {
+      {"compile", &compile,
+       "If true, compile the computation using the default client before "
+       "dumping the HLO. Otherwise dump the raw (uncompiled) HLO."},
+  };
+  const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  QCHECK(parsed_flags_ok) << "\n" << usage;
+
+  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+  QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
 
   tensorflow::gtl::ArraySlice<char*> args(argv, argc);
   args.pop_front();  // Pop off the binary name, argv[0]
-  xla::tools::RealMain(args);
+  xla::tools::RealMain(args, compile);
   return 0;
 }
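For reference, a minimal self-contained sketch of the flag-parsing pattern the new main() adopts above (header paths as in TensorFlow at the time of this change; the flag name "print_extra" is a made-up example, not part of this commit). The ordering matters: Flags::Parse runs first so recognized flags are consumed and stripped from argv, and the generated usage string can then be handed to InitMain, leaving only positional arguments for the tool itself.

#include <vector>

#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char** argv) {
  bool print_extra = false;  // made-up example flag, not from the commit
  std::vector<tensorflow::Flag> flag_list = {
      {"print_extra", &print_extra, "Example boolean flag."},
  };
  const tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  // Parse first: recognized flags are consumed and removed from argv.
  bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  QCHECK(parsed_flags_ok) << "\n" << usage;
  // InitMain then sees only the binary name plus positional arguments.
  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
  QCHECK(argc > 1) << "\nERROR: expected at least one positional argument\n"
                   << usage;
  return 0;
}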