[XLA] Redesign: delete SessionModule.

PiperOrigin-RevId: 199361402
This commit is contained in:
A. Unique TensorFlower 2018-06-05 14:46:22 -07:00 committed by TensorFlower Gardener
parent d935dd9d99
commit f0230735d1
14 changed files with 18 additions and 202 deletions

View File

@ -53,7 +53,6 @@ xla_proto_library(
deps = [
":xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:session_proto",
],
)

View File

@ -110,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",

View File

@ -185,7 +185,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
run_options, backend_->StreamBorrower(),
backend_->eigen_intra_op_thread_pool());
if (executable_->dumping()) {
if (executable_->dumping_snapshot()) {
return ExecuteAndDump(&service_options, arguments);
}
return executable_->ExecuteOnStreamWrapper(
@ -195,36 +195,36 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
executable_->session_module()->set_execution_platform(
executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
executable_->ExecuteOnStream(run_options, arguments,
/*hlo_execution_profile=*/nullptr));
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module()));
TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot());
return std::move(result);
}
Status LocalExecutable::RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module) {
session_module->clear_arguments();
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*argument));
*session_module->add_arguments() = literal->ToProto();
*hlo_snapshot->add_arguments() = literal->ToProto();
}
return Status::OK();
}
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
SessionModule* session_module) {
session_module->clear_result();
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*result));
*session_module->mutable_result() = literal->ToProto();
*hlo_snapshot->mutable_result() = literal->ToProto();
return Status::OK();
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -78,11 +79,10 @@ class LocalExecutable {
// proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module);
HloSnapshot* hlo_snapshot);
// Records the result of the computation in a SessionModule proto.
Status RecordResult(const ShapedBuffer* result,
SessionModule* session_module);
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(

View File

@ -21,13 +21,6 @@ load(
"tf_proto_library_py",
)
xla_proto_library(
name = "session_proto",
srcs = ["session.proto"],
visibility = ["//visibility:public"],
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
)
xla_proto_library(
name = "hlo_proto",
srcs = ["hlo.proto"],
@ -608,7 +601,6 @@ cc_library(
":hlo_module_config",
":hlo_proto_util",
":platform_util",
":session_proto",
":source_map_util",
":transfer_manager",
":versioned_computation_handle",
@ -766,7 +758,6 @@ cc_library(
":hlo_graph_dumper",
":hlo_proto",
":pool",
":session_proto",
":shaped_buffer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
@ -870,7 +861,6 @@ cc_library(
hdrs = ["channel_tracker.h"],
deps = [
":hlo",
":session_proto",
":versioned_computation_handle",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <map>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -129,20 +129,6 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
return return_value;
}
Status Executable::DumpSessionModule() {
TF_RET_CHECK(dumping());
const string& directory_path =
module_config().debug_options().xla_dump_executions_to();
VersionedComputationHandle versioned_handle = entry_computation_handle();
// This filename does not include the version number because the computation
// is only ever executed at one version.
string filename = tensorflow::strings::Printf(
"computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(),
session_module_->entry().name().c_str(), ++execution_count_);
return Executable::DumpToDirectory(directory_path, filename,
*session_module_);
}
Status Executable::DumpHloSnapshot() {
TF_RET_CHECK(dumping_snapshot());
TF_RET_CHECK(hlo_snapshot_->has_hlo() &&
@ -156,26 +142,6 @@ Status Executable::DumpHloSnapshot() {
return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_);
}
/* static */ Status Executable::DumpToDirectory(
const string& directory_path, string filename,
const SessionModule& session_module) {
tensorflow::Env* env = tensorflow::Env::Default();
if (!env->IsDirectory(directory_path).ok()) {
// NB! CreateDir does not work reliably with multiple XLA threads -- two
// threads can race to observe the absence of the dump directory and
// simultaneously try to create it, causing the "losing" thread to get a
// "directory already exists" error.
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path));
}
filename = SanitizeFileName(std::move(filename));
string file_path = tensorflow::io::JoinPath(directory_path, filename);
string result;
TF_RET_CHECK(
tensorflow::SerializeToStringDeterministic(session_module, &result));
return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path,
result);
}
/* static */ Status Executable::DumpToDirectory(
const string& directory_path, string filename,
const HloSnapshot& hlo_session) {

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -144,14 +143,6 @@ class Executable {
return hlo_module_->config().host_entry_computation_layout().result_shape();
}
// TODO(b/74197823): Delete the session module dumping helpers.
void set_session_module(std::unique_ptr<xla::SessionModule> session_module) {
session_module_ = std::move(session_module);
}
bool dumping() const { return session_module_ != nullptr; }
SessionModule* session_module() const { return session_module_.get(); }
Status DumpSessionModule();
// Dumping helpers.
void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
hlo_snapshot_ = std::move(hlo_snapshot);
@ -160,10 +151,6 @@ class Executable {
HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }
Status DumpHloSnapshot();
// Dump session_module to directory_path/filename.
static Status DumpToDirectory(const string& directory_path, string filename,
const SessionModule& session_module);
// Dump hlo snapshot to directory_path/filename.
static Status DumpToDirectory(const string& directory_path, string filename,
const HloSnapshot& hlo_session);
@ -179,9 +166,6 @@ class Executable {
// around.
const std::unique_ptr<const HloModule> hlo_module_;
// SessionModule this was compiled from. Null if not dumping executions.
std::unique_ptr<SessionModule> session_module_;
// HloSnapshot this was compiled from. Null if not dumping executions.
std::unique_ptr<HloSnapshot> hlo_snapshot_;

View File

@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
@ -62,33 +61,6 @@ namespace xla {
namespace {
// Records the arguments used to invoke a computation in a SessionModule
// proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
se::StreamExecutor* executor, TransferManager* transfer_manager,
SessionModule* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
transfer_manager->TransferLiteralFromDevice(executor, *argument));
*module->add_arguments() = literal->ToProto();
}
return Status::OK();
}
// Records the result of a computation in a SessionModule proto.
Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
TransferManager* transfer_manager, SessionModule* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
transfer_manager->TransferLiteralFromDevice(executor, result));
*module->mutable_result() = literal->ToProto();
return Status::OK();
}
// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -1,85 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This proto file defines messages which store the state of XLA
// computations within the XLA service. A computation is stored as a record
// of the operation requests used to build it.
syntax = "proto3";
import "tensorflow/compiler/xla/xla_data.proto";
package xla;
// Describes a single operation request.
message OperationRequest {
ComputationDataHandle output_handle = 1;
Shape output_shape = 2;
// For operations which call embedded computations such as "Map", these are
// the version(s) that the embedded computation should be called at. A version
// value of a computation is the ComputationDataHandle of the root of the
// computation at the point in time.
//
// "Call", "Map", "Reduce", and "ReduceWindow" operations take a single
// embedded computation so this field will have a single value for those
// operations.
//
// "While" operation takes two; index 0 is the "condition" version and index 1
// is the "body" version.
repeated int64 embedded_computation_versions = 3;
// The actual request, which in itself is a tagged union of all possible
// operation request types.
OpRequest request = 4;
}
// Describes a sequence of operation requests which define an XLA
// computation.
message SessionComputation {
string name = 1;
// The ComputationHandle used to refer to this computation in the XLA
// service.
ComputationHandle computation_handle = 2;
// Map from ComputationDataHandle value to operation request. The highest
// ComputationDataHandle value corresponds to the root of the computation.
map<int64, OperationRequest> requests = 3;
}
// Describes a group of SessionComputations with an "entry point" computation
// that may refer to the other non-entry (AKA embedded) computations.
//
// This message is used to serialize a computation that has been built via the
// XLA service API, along with its dependencies, for purposes such as
// analysis/replay/file-storage.
message SessionModule {
// The entry computation, which was requested for serialization. This may have
// referred to embedded computations, which are reflected below.
SessionComputation entry = 1;
// Embedded computations that are transitively referred to by the entry
// computation.
repeated SessionComputation embedded_computations = 2;
// The arguments passed to the computation.
repeated LiteralProto arguments = 3;
// The result of the computation.
LiteralProto result = 4;
// The name of the platform used to run the computation.
string execution_platform = 5;
}

View File

@ -135,7 +135,7 @@ tf_cc_binary(
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)

View File

@ -21,7 +21,7 @@ limitations under the License.
#include <unistd.h>
#include <string>
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
@ -33,7 +33,7 @@ namespace xla {
namespace tools {
void RealMain(const string& mode, const string& path) {
SessionModule module;
HloSnapshot module;
tensorflow::Env* env = tensorflow::Env::Default();
if (mode == "txt2bin") {
TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module));

View File

@ -17,7 +17,6 @@ syntax = "proto3";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/service/session.proto";
package xla;
@ -230,14 +229,6 @@ message SnapshotComputationRequest {
ComputationHandle computation = 1;
}
message SnapshotComputationResponse {
SessionModule module = 1;
}
message LoadComputationSnapshotRequest {
SessionModule module = 1;
}
message LoadComputationSnapshotResponse {
ComputationHandle computation = 1;
}