[XLA] Redesign: delete Client::LoadSnapshot(SessionModule). This is a precondition for deleting xla::Computation.
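
Callers now load a serialized HloSnapshot and get back an XlaComputation. A minimal sketch of that surviving call path, mirroring the tool updates in this commit (the LoadSnapshotFromFile wrapper and its name are illustrative, not part of this change):

  #include <string>

  #include "tensorflow/compiler/xla/client/client.h"
  #include "tensorflow/compiler/xla/client/client_library.h"
  #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
  #include "tensorflow/compiler/xla/service/hlo.pb.h"
  #include "tensorflow/core/lib/core/status.h"
  #include "tensorflow/core/platform/env.h"

  // Read an HloSnapshot proto from disk and turn it into an XlaComputation.
  // Only the HloSnapshot overload of Client::LoadSnapshot remains; it builds
  // the XlaComputation from snapshot.hlo().hlo_module().
  xla::XlaComputation LoadSnapshotFromFile(const std::string& path) {
    xla::Client* client = xla::ClientLibrary::LocalClientOrDie();
    xla::HloSnapshot snapshot;
    TF_CHECK_OK(
        tensorflow::ReadBinaryProto(tensorflow::Env::Default(), path, &snapshot));
    return client->LoadSnapshot(snapshot).ConsumeValueOrDie();
  }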

PiperOrigin-RevId: 197033641
Authored by A. Unique TensorFlower on 2018-05-17 12:31:17 -07:00; committed by TensorFlower Gardener
parent 01dbc6ac45
commit 9a815b422a
10 changed files with 49 additions and 105 deletions


@@ -76,7 +76,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)


@@ -221,20 +221,6 @@ StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
return Literal::CreateFromProto(response.literal());
}
StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
LoadComputationSnapshotRequest request;
*request.mutable_module() = module;
LoadComputationSnapshotResponse response;
Status s = stub_->LoadComputationSnapshot(&request, &response);
if (!s.ok()) {
return s;
}
VLOG(1) << "load snapshot response: " << response.ShortDebugString();
return Computation(stub_, response.computation());
}
StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
return XlaComputation(module.hlo().hlo_module());


@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -253,9 +253,6 @@ class Client {
// two computations via a pair of Send and Recv instructions.
StatusOr<ChannelHandle> CreateChannelHandle();
StatusOr<Computation> LoadSnapshot(const SessionModule& module);
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
ServiceInterface* stub() { return stub_; }


@@ -40,7 +40,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)
@@ -65,8 +65,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
@@ -89,7 +89,6 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
@@ -169,8 +168,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:computation_tracker",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
@@ -188,8 +187,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
@@ -207,8 +206,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)


@@ -17,7 +17,7 @@ limitations under the License.
//
// Dumps a graphviz URL for a snapshot computation to the command line.
//
// some_binary_snapshot_proto is obtained by serializing the SessionModule from
// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// The GraphViz URL is placed into the log stderr, whereas computation
@@ -30,11 +30,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,10 +48,11 @@ namespace tools {
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
SessionModule module;
HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
XlaComputation computation =
client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
ComputationStats stats =


@@ -21,11 +21,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
for (char* arg : args) {
SessionModule session_module;
HloSnapshot snapshot;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
&session_module));
auto computation_status = client->LoadSnapshot(session_module);
&snapshot));
auto computation_status = client->LoadSnapshot(snapshot);
if (!computation_status.ok()) {
fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
computation_status.status().ToString().c_str());
continue;
}
Computation computation = computation_status.ConsumeValueOrDie();
XlaComputation computation = computation_status.ConsumeValueOrDie();
std::unique_ptr<ProgramShape> program_shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
@@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation.handle(), layouts,
build_options);
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
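
The XlaComputation produced by LoadSnapshot is now handed to LocalService::CompileExecutable directly, instead of a Computation handle. A rough sketch of that compile step as a standalone helper (the helper name is illustrative, and the argument-layout loop is assumed to match the elided tool code, which builds a std::vector<const Shape*> from the program shape's parameters):

  #include <memory>
  #include <vector>

  #include "tensorflow/compiler/xla/client/client.h"
  #include "tensorflow/compiler/xla/client/local_client.h"
  #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
  #include "tensorflow/compiler/xla/service/executable.h"
  #include "tensorflow/compiler/xla/service/local_service.h"

  // Hypothetical helper: compile a loaded XlaComputation with the local
  // service, using the program shape for argument and result layouts.
  std::unique_ptr<xla::Executable> CompileSnapshotComputation(
      xla::Client* client, xla::LocalService* local_service,
      const xla::XlaComputation& computation) {
    std::unique_ptr<xla::ProgramShape> program_shape =
        client->GetComputationShape(computation).ConsumeValueOrDie();
    // Assumed layout setup, matching the surrounding (elided) tool code.
    std::vector<const xla::Shape*> layouts;
    for (int i = 0; i < program_shape->parameters_size(); ++i) {
      layouts.push_back(&program_shape->parameters(i));
    }
    xla::ExecutableBuildOptions build_options;
    build_options.set_device_ordinal(0);
    build_options.set_result_layout(program_shape->result());
    return local_service
        ->CompileExecutable(computation, layouts, build_options)
        .ConsumeValueOrDie();
  }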


@@ -19,11 +19,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -40,16 +39,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
for (char* arg : args) {
SessionModule session_module;
HloSnapshot snapshot;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
&session_module));
auto computation_status = client->LoadSnapshot(session_module);
&snapshot));
auto computation_status = client->LoadSnapshot(snapshot);
if (!computation_status.ok()) {
fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
computation_status.status().ToString().c_str());
continue;
}
Computation computation = computation_status.ConsumeValueOrDie();
XlaComputation computation = computation_status.ConsumeValueOrDie();
if (compile) {
std::unique_ptr<ProgramShape> program_shape =
@@ -65,8 +64,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation.handle(), layouts,
build_options);
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
@@ -74,13 +72,11 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
local_service->backend().platform()->Name().c_str(),
module.ToString(HloPrintOptions::ShortParsable()).c_str());
} else {
const ComputationTracker& tracker = local_service->computation_tracker();
UserComputation* user_computation =
tracker.Resolve(computation.handle()).ConsumeValueOrDie();
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
DebugOptions())
.ConsumeValueOrDie();
std::unique_ptr<HloModule> module =
tracker.BuildHloModule(versioned_handle, HloModuleConfig())
HloModule::CreateFromProto(computation.proto(), config)
.ConsumeValueOrDie();
fprintf(stdout, "%s\n",
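
The text dump above no longer resolves handles through the ComputationTracker; it rebuilds an HloModule directly from the HloModuleProto carried by the XlaComputation. A compact sketch of that conversion (the helper name is illustrative, and DebugOptions is left at its defaults):

  #include <memory>

  #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
  #include "tensorflow/compiler/xla/service/hlo_module.h"
  #include "tensorflow/compiler/xla/service/hlo_module_config.h"
  #include "tensorflow/compiler/xla/xla.pb.h"

  // Hypothetical helper: reconstruct an HloModule from an XlaComputation's
  // proto, using a module config derived from the same proto.
  std::unique_ptr<xla::HloModule> ModuleFromComputation(
      const xla::XlaComputation& computation) {
    xla::HloModuleConfig config =
        xla::HloModule::CreateModuleConfigFromProto(computation.proto(),
                                                    xla::DebugOptions())
            .ConsumeValueOrDie();
    return xla::HloModule::CreateFromProto(computation.proto(), config)
        .ConsumeValueOrDie();
  }

Printing the result then reduces to module->ToString(HloPrintOptions::ShortParsable()), as in the updated tools.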


@@ -28,11 +28,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -48,10 +47,11 @@ namespace tools {
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
SessionModule module;
HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
XlaComputation computation =
client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
debug_options.set_xla_hlo_dump_as_graphdef(true);


@@ -17,7 +17,7 @@ limitations under the License.
//
// Replays computations and shows the results on the command line.
//
// some_binary_snapshot_proto is obtained by serializing the SessionModule from
// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// Computations that require arguments can be replayed using fake data by
@@ -36,14 +36,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -76,13 +74,9 @@ struct Options {
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
// otherwise, no infeed is performed.
template <typename ModuleT>
StatusOr<std::unique_ptr<Literal>> ReplayComputation(const ModuleT& module,
StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
Client* client,
const Options& opts) {
static_assert(std::is_same<ModuleT, HloSnapshot>::value ||
std::is_same<ModuleT, SessionModule>::value,
"Proto must be in HloSnapshot or SessionModule format");
TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module));
std::vector<std::unique_ptr<GlobalData>> arguments;
@@ -161,40 +155,13 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
for (char* arg : args) {
HloSnapshot snapshot;
auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
if (status.ok()) {
StatusOr<std::unique_ptr<Literal>> result_status =
ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
exit_status = EXIT_FAILURE;
continue;
}
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
if (result != nullptr) {
fprintf(stdout, "%s: %s :: %s:%s\n", arg,
snapshot.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(result->shape()).c_str(),
result->ToString().c_str());
if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
literal->ToString().c_str());
}
}
if (!status.ok()) {
fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg,
status.ToString().c_str());
continue;
}
fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n",
arg, status.ToString().c_str());
SessionModule module;
TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
StatusOr<std::unique_ptr<Literal>> result_status =
ReplayComputation(module, client, opts);
ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
@@ -204,14 +171,15 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
if (result != nullptr) {
fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
fprintf(stdout, "%s: %s :: %s:%s\n", arg,
snapshot.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(result->shape()).c_str(),
result->ToString().c_str());
if (module.has_result()) {
if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(),
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
literal->ToString().c_str());
}
}


@@ -18,7 +18,7 @@ limitations under the License.
// Shows the signature (ProgramShape) of binary snapshot proto(s) on the command
// line.
//
// some_binary_snapshot_proto is obtained by serializing the SessionModule from
// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// The output format is:
@@ -31,9 +31,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -49,13 +48,14 @@ namespace tools {
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
SessionModule module;
HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
auto computation = client->LoadSnapshot(module).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(),
fprintf(stdout, "%s: %s :: %s\n", arg,
module.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(*shape).c_str());
}
}