[XLA:HLO] Change dumped_computation_to_text to dump the raw HLO by default.
Added a --compile flag (default false) to enable the old behavior of dumping the HLO after compilation. It's not clear that the --compile mode is actually useful; to actually be useful we'd probably want to add a --platform flag to easily select the backend as well. Change: 154379349
This commit is contained in:
parent
24316aa70c
commit
be43153b21
@ -153,6 +153,7 @@ cc_binary(
|
|||||||
"//tensorflow/compiler/xla/client:computation",
|
"//tensorflow/compiler/xla/client:computation",
|
||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:local_client",
|
||||||
"//tensorflow/compiler/xla/service",
|
"//tensorflow/compiler/xla/service",
|
||||||
|
"//tensorflow/compiler/xla/service:computation_tracker",
|
||||||
"//tensorflow/compiler/xla/service:session_proto",
|
"//tensorflow/compiler/xla/service:session_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/computation.h"
|
#include "tensorflow/compiler/xla/client/computation.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/computation_tracker.h"
|
||||||
#include "tensorflow/compiler/xla/service/service.h"
|
#include "tensorflow/compiler/xla/service/service.h"
|
||||||
#include "tensorflow/compiler/xla/service/session.pb.h"
|
#include "tensorflow/compiler/xla/service/session.pb.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
@ -34,7 +35,7 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace tools {
|
namespace tools {
|
||||||
|
|
||||||
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
|
void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
|
||||||
LocalClient* client = ClientLibrary::LocalClientOrDie();
|
LocalClient* client = ClientLibrary::LocalClientOrDie();
|
||||||
LocalService* local_service =
|
LocalService* local_service =
|
||||||
ClientLibrary::GetXlaService(client->platform());
|
ClientLibrary::GetXlaService(client->platform());
|
||||||
@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
|
|||||||
}
|
}
|
||||||
Computation computation = computation_status.ConsumeValueOrDie();
|
Computation computation = computation_status.ConsumeValueOrDie();
|
||||||
|
|
||||||
std::unique_ptr<ProgramShape> program_shape =
|
if (compile) {
|
||||||
client->GetComputationShape(computation).ConsumeValueOrDie();
|
std::unique_ptr<ProgramShape> program_shape =
|
||||||
|
client->GetComputationShape(computation).ConsumeValueOrDie();
|
||||||
|
|
||||||
std::vector<const Shape*> layouts;
|
std::vector<const Shape*> layouts;
|
||||||
for (int i = 0; i < program_shape->parameters_size(); ++i) {
|
for (int i = 0; i < program_shape->parameters_size(); ++i) {
|
||||||
layouts.push_back(&program_shape->parameters(i));
|
layouts.push_back(&program_shape->parameters(i));
|
||||||
|
}
|
||||||
|
StatusOr<std::unique_ptr<Executable>> executable =
|
||||||
|
local_service->CompileExecutable(
|
||||||
|
computation.handle(), layouts, &program_shape->result(),
|
||||||
|
/*device_ordinal=*/0, /*has_hybrid_result=*/true);
|
||||||
|
|
||||||
|
const HloModule& module = executable.ValueOrDie()->module();
|
||||||
|
|
||||||
|
fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
|
||||||
|
local_service->backend().platform()->Name().c_str(),
|
||||||
|
module.ToString().c_str());
|
||||||
|
} else {
|
||||||
|
const ComputationTracker& tracker = local_service->computation_tracker();
|
||||||
|
UserComputation* user_computation =
|
||||||
|
tracker.Resolve(computation.handle()).ConsumeValueOrDie();
|
||||||
|
VersionedComputationHandle versioned_handle =
|
||||||
|
user_computation->GetVersionedHandle();
|
||||||
|
std::unique_ptr<HloModule> module =
|
||||||
|
tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie();
|
||||||
|
|
||||||
|
fprintf(stdout, "%s\n", module->ToString().c_str());
|
||||||
}
|
}
|
||||||
StatusOr<std::unique_ptr<Executable>> executable =
|
|
||||||
local_service->CompileExecutable(
|
|
||||||
computation.handle(), layouts, &program_shape->result(),
|
|
||||||
/*device_ordinal=*/0, /*has_hybrid_result=*/true);
|
|
||||||
|
|
||||||
const HloModule& module = executable.ValueOrDie()->module();
|
|
||||||
|
|
||||||
fprintf(stdout, "HLO for %s backend:\n%s\n",
|
|
||||||
local_service->backend().platform()->Name().c_str(),
|
|
||||||
module.ToString().c_str());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
|
|||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
bool compile = false;
|
||||||
|
std::vector<tensorflow::Flag> flag_list = {
|
||||||
|
{"compile", &compile,
|
||||||
|
"If true, compile the computation using the default client before "
|
||||||
|
"dumping the HLO. Otherwise dump the raw (uncompiled) HLO."},
|
||||||
|
};
|
||||||
|
const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
|
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
|
QCHECK(parsed_flags_ok) << "\n" << usage;
|
||||||
|
|
||||||
|
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||||
|
QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
|
||||||
|
|
||||||
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
|
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
|
||||||
args.pop_front(); // Pop off the binary name, argv[0]
|
args.pop_front(); // Pop off the binary name, argv[0]
|
||||||
xla::tools::RealMain(args);
|
xla::tools::RealMain(args, compile);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user