Add metadata for gathering information about host compute transfers while compiling XLA.
PiperOrigin-RevId: 188102740
This commit is contained in:
parent
7efc16ed02
commit
6e99d56489
@ -58,6 +58,15 @@ xla_proto_library(
|
||||
],
|
||||
)
|
||||
|
||||
xla_proto_library(
|
||||
name = "host_compute_metadata_proto",
|
||||
srcs = ["host_compute_metadata.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf2xla",
|
||||
srcs = ["tf2xla.cc"],
|
||||
@ -149,6 +158,7 @@ cc_library(
|
||||
":common",
|
||||
":dump_graph",
|
||||
":functionalize_control_flow",
|
||||
":host_compute_metadata_proto",
|
||||
":sharding_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
|
38
tensorflow/compiler/tf2xla/host_compute_metadata.proto
Normal file
38
tensorflow/compiler/tf2xla/host_compute_metadata.proto
Normal file
@ -0,0 +1,38 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.tf2xla;
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "Tf2XlaProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.tf2xla";
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// TensorMetadata indicates the type and shape of a Tensor that is
|
||||
// part of a host compute transfer.
|
||||
message TensorMetadata {
|
||||
DataType type = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
// HostTransferMetadata describes a transfer either from host to device
|
||||
// or device to host. It has a key that is unique to the computation,
|
||||
// and metadata about the list of tensors being transferred.
|
||||
message HostTransferMetadata {
|
||||
// The key used to identify this transfer.
|
||||
string key = 1;
|
||||
|
||||
// For each Tensor being transferred, its type and shape.
|
||||
repeated TensorMetadata metadata = 2;
|
||||
}
|
||||
|
||||
// HostComputeMetadata describes all the sends and recvs
|
||||
// from all host compute transfer ops in a computation.
|
||||
message HostComputeMetadata {
|
||||
// Metadata about each device_to_host transfer
|
||||
repeated HostTransferMetadata device_to_host = 1;
|
||||
|
||||
// Metadata about each host_to_device transfer
|
||||
repeated HostTransferMetadata host_to_device = 2;
|
||||
}
|
@ -674,6 +674,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
VLOG(2) << "XLA output shape: "
|
||||
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
||||
|
||||
// Copy the host transfer metadata to the result.
|
||||
for (const auto& send : host_compute_sends_) {
|
||||
*result->host_compute_metadata.add_device_to_host() = send.second;
|
||||
}
|
||||
for (const auto& recv : host_compute_recvs_) {
|
||||
*result->host_compute_metadata.add_host_to_device() = recv.second;
|
||||
}
|
||||
|
||||
// Tensorflow expects a major-to-minor order of results.
|
||||
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
|
||||
|
||||
@ -708,4 +716,59 @@ Status XlaCompiler::GetChannelHandle(const string& key,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void SetTransfer(const string& key, const std::vector<DataType>& types,
|
||||
const std::vector<TensorShape>& shapes,
|
||||
tf2xla::HostTransferMetadata* transfer) {
|
||||
transfer->set_key(key);
|
||||
CHECK(types.size() == shapes.size());
|
||||
for (int i = 0; i < types.size(); ++i) {
|
||||
tf2xla::TensorMetadata* metadata = transfer->add_metadata();
|
||||
metadata->set_type(types[i]);
|
||||
shapes[i].AsProto(metadata->mutable_shape());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaCompiler::SetDeviceToHostMetadata(
|
||||
const string& key, const std::vector<DataType>& types,
|
||||
const std::vector<TensorShape>& shapes) {
|
||||
if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Duplicate calls to SetDeviceToHostMetadata with key ", key);
|
||||
}
|
||||
tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
|
||||
SetTransfer(key, types, shapes, &transfer);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaCompiler::GetDeviceToHostShapes(
|
||||
const string& key, std::vector<TensorShape>* shapes) const {
|
||||
const auto iter = host_compute_sends_.find(key);
|
||||
if (iter == host_compute_sends_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"No host compute send shapes registered for key ", key);
|
||||
}
|
||||
shapes->clear();
|
||||
for (int i = 0; i < iter->second.metadata_size(); ++i) {
|
||||
TensorShape shape(iter->second.metadata(i).shape());
|
||||
shapes->push_back(shape);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaCompiler::SetHostToDeviceMetadata(
|
||||
const string& key, const std::vector<DataType>& types,
|
||||
const std::vector<TensorShape>& shapes) {
|
||||
if (host_compute_recvs_.find(key) != host_compute_sends_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Duplicate calls to SetHostToDeviceMetadata with key ", key);
|
||||
}
|
||||
tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
|
||||
SetTransfer(key, types, shapes, &transfer);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -216,6 +217,10 @@ class XlaCompiler {
|
||||
// containing both constant and non-constant results.
|
||||
std::vector<OutputDescription> outputs;
|
||||
|
||||
// TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
|
||||
// matching RecvAtHost/SendFromHost Ops in the outer graph.
|
||||
tf2xla::HostComputeMetadata host_compute_metadata;
|
||||
|
||||
// Resources whose values were updated by the computation, ordered
|
||||
// by return value position. Resource updates follow the non-constant
|
||||
// results in the outputs of XLA computation.
|
||||
@ -296,6 +301,22 @@ class XlaCompiler {
|
||||
// same XlaCompiler.
|
||||
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
|
||||
|
||||
// Sets the shapes and types for the device to host transfer associated with
|
||||
// 'key'.
|
||||
Status SetDeviceToHostMetadata(const string& key,
|
||||
const std::vector<DataType>& types,
|
||||
const std::vector<TensorShape>& shapes);
|
||||
|
||||
// Gets the shapes the device to host transfer associated with 'key'.
|
||||
Status GetDeviceToHostShapes(const string& key,
|
||||
std::vector<TensorShape>* shapes) const;
|
||||
|
||||
// Sets the shapes and types for the host to device transfer associated with
|
||||
// 'key'.
|
||||
Status SetHostToDeviceMetadata(const string& key,
|
||||
const std::vector<DataType>& types,
|
||||
const std::vector<TensorShape>& shapes);
|
||||
|
||||
const Options& options() const { return options_; }
|
||||
xla::Client* client() const { return options_.client; }
|
||||
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
|
||||
@ -359,6 +380,9 @@ class XlaCompiler {
|
||||
|
||||
std::unordered_map<string, xla::ChannelHandle> channels_;
|
||||
|
||||
std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
|
||||
std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user