Add metadata for gathering information about host compute transfers while compiling XLA.

PiperOrigin-RevId: 188102740
Authored by A. Unique TensorFlower on 2018-03-06 16:46:54 -08:00; committed by TensorFlower Gardener
parent 7efc16ed02
commit 6e99d56489
4 changed files with 135 additions and 0 deletions

tensorflow/compiler/tf2xla/BUILD

@@ -58,6 +58,15 @@ xla_proto_library(
     ],
 )
 
+xla_proto_library(
+    name = "host_compute_metadata_proto",
+    srcs = ["host_compute_metadata.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 cc_library(
     name = "tf2xla",
     srcs = ["tf2xla.cc"],
@@ -149,6 +158,7 @@ cc_library(
         ":common",
         ":dump_graph",
         ":functionalize_control_flow",
+        ":host_compute_metadata_proto",
         ":sharding_util",
         ":tf2xla_util",
         "//tensorflow/compiler/tf2xla/lib:util",

tensorflow/compiler/tf2xla/host_compute_metadata.proto

@@ -0,0 +1,38 @@
+syntax = "proto3";
+
+package tensorflow.tf2xla;
+option cc_enable_arenas = true;
+option java_outer_classname = "Tf2XlaProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.tf2xla";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// TensorMetadata indicates the type and shape of a Tensor that is
+// part of a host compute transfer.
+message TensorMetadata {
+  DataType type = 1;
+  TensorShapeProto shape = 2;
+}
+
+// HostTransferMetadata describes a transfer either from host to device
+// or device to host. It has a key that is unique to the computation,
+// and metadata about the list of tensors being transferred.
+message HostTransferMetadata {
+  // The key used to identify this transfer.
+  string key = 1;
+
+  // For each Tensor being transferred, its type and shape.
+  repeated TensorMetadata metadata = 2;
+}
+
+// HostComputeMetadata describes all the sends and recvs
+// from all host compute transfer ops in a computation.
+message HostComputeMetadata {
+  // Metadata about each device_to_host transfer
+  repeated HostTransferMetadata device_to_host = 1;
+
+  // Metadata about each host_to_device transfer
+  repeated HostTransferMetadata host_to_device = 2;
+}

tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -674,6 +674,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
+  // Copy the host transfer metadata to the result.
+  for (const auto& send : host_compute_sends_) {
+    *result->host_compute_metadata.add_device_to_host() = send.second;
+  }
+  for (const auto& recv : host_compute_recvs_) {
+    *result->host_compute_metadata.add_host_to_device() = recv.second;
+  }
+
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
@@ -708,4 +716,59 @@ Status XlaCompiler::GetChannelHandle(const string& key,
   return Status::OK();
 }
+
+namespace {
+
+void SetTransfer(const string& key, const std::vector<DataType>& types,
+                 const std::vector<TensorShape>& shapes,
+                 tf2xla::HostTransferMetadata* transfer) {
+  transfer->set_key(key);
+  CHECK(types.size() == shapes.size());
+  for (int i = 0; i < types.size(); ++i) {
+    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
+    metadata->set_type(types[i]);
+    shapes[i].AsProto(metadata->mutable_shape());
+  }
+}
+
+}  // namespace
+
+Status XlaCompiler::SetDeviceToHostMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostShapes(
+    const string& key, std::vector<TensorShape>* shapes) const {
+  const auto iter = host_compute_sends_.find(key);
+  if (iter == host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "No host compute send shapes registered for key ", key);
+  }
+  shapes->clear();
+  for (int i = 0; i < iter->second.metadata_size(); ++i) {
+    TensorShape shape(iter->second.metadata(i).shape());
+    shapes->push_back(shape);
+  }
+  return Status::OK();
+}
+
+Status XlaCompiler::SetHostToDeviceMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
 }  // namespace tensorflow
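For orientation, the intended call pattern appears to be that whichever component knows the transfer shapes registers them on the compiler before compilation and can query them back by key. A minimal usage sketch follows, in which `compiler`, the key "host_compute_channel_0", and the single [2,3] float tensor are illustrative assumptions:

// Sketch only: `compiler` is an XlaCompiler constructed elsewhere.
std::vector<tensorflow::DataType> types = {tensorflow::DT_FLOAT};
std::vector<tensorflow::TensorShape> shapes = {tensorflow::TensorShape({2, 3})};
TF_RETURN_IF_ERROR(
    compiler.SetDeviceToHostMetadata("host_compute_channel_0", types, shapes));

// The registered shapes can later be read back under the same key.
std::vector<tensorflow::TensorShape> registered_shapes;
TF_RETURN_IF_ERROR(compiler.GetDeviceToHostShapes("host_compute_channel_0",
                                                  &registered_shapes));

Note that a second Set*Metadata call with the same key returns InvalidArgument rather than overwriting, so each transfer key is registered exactly once per compilation.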

tensorflow/compiler/tf2xla/xla_compiler.h

@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -216,6 +217,10 @@ class XlaCompiler {
     // containing both constant and non-constant results.
     std::vector<OutputDescription> outputs;
 
+    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+    // matching RecvAtHost/SendFromHost Ops in the outer graph.
+    tf2xla::HostComputeMetadata host_compute_metadata;
+
     // Resources whose values were updated by the computation, ordered
     // by return value position. Resource updates follow the non-constant
     // results in the outputs of XLA computation.
@@ -296,6 +301,22 @@ class XlaCompiler {
   // same XlaCompiler.
   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
 
+  // Sets the shapes and types for the device to host transfer associated with
+  // 'key'.
+  Status SetDeviceToHostMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
+  // Gets the shapes of the device to host transfer associated with 'key'.
+  Status GetDeviceToHostShapes(const string& key,
+                               std::vector<TensorShape>* shapes) const;
+
+  // Sets the shapes and types for the host to device transfer associated with
+  // 'key'.
+  Status SetHostToDeviceMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
   const Options& options() const { return options_; }
   xla::Client* client() const { return options_.client; }
   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
@@ -359,6 +380,9 @@ class XlaCompiler {
   std::unordered_map<string, xla::ChannelHandle> channels_;
 
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
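Since CompileGraph() now copies the registered transfers into CompilationResult::host_compute_metadata, a consumer of the compilation result can enumerate them. A brief sketch, where `result` is assumed to be a CompilationResult already populated by CompileGraph():

// Sketch only: walk the device->host transfers recorded during compilation.
for (const auto& send : result.host_compute_metadata.device_to_host()) {
  LOG(INFO) << "device_to_host transfer '" << send.key() << "' carries "
            << send.metadata_size() << " tensor(s)";
}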