Add metadata for gathering information about host compute transfers while compiling XLA.
PiperOrigin-RevId: 188102740
commit 6e99d56489
parent 7efc16ed02

tensorflow/compiler/tf2xla/BUILD

@@ -58,6 +58,15 @@ xla_proto_library(
     ],
 )
 
+xla_proto_library(
+    name = "host_compute_metadata_proto",
+    srcs = ["host_compute_metadata.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 cc_library(
     name = "tf2xla",
     srcs = ["tf2xla.cc"],
@@ -149,6 +158,7 @@ cc_library(
         ":common",
         ":dump_graph",
         ":functionalize_control_flow",
+        ":host_compute_metadata_proto",
         ":sharding_util",
         ":tf2xla_util",
         "//tensorflow/compiler/tf2xla/lib:util",

tensorflow/compiler/tf2xla/host_compute_metadata.proto (new file, 38 lines)

@@ -0,0 +1,38 @@
+syntax = "proto3";
+
+package tensorflow.tf2xla;
+option cc_enable_arenas = true;
+option java_outer_classname = "Tf2XlaProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.tf2xla";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// TensorMetadata indicates the type and shape of a Tensor that is
+// part of a host compute transfer.
+message TensorMetadata {
+  DataType type = 1;
+  TensorShapeProto shape = 2;
+}
+
+// HostTransferMetadata describes a transfer either from host to device
+// or device to host. It has a key that is unique to the computation,
+// and metadata about the list of tensors being transferred.
+message HostTransferMetadata {
+  // The key used to identify this transfer.
+  string key = 1;
+
+  // For each Tensor being transferred, its type and shape.
+  repeated TensorMetadata metadata = 2;
+}
+
+// HostComputeMetadata describes all the sends and recvs
+// from all host compute transfer ops in a computation.
+message HostComputeMetadata {
+  // Metadata about each device_to_host transfer
+  repeated HostTransferMetadata device_to_host = 1;
+
+  // Metadata about each host_to_device transfer
+  repeated HostTransferMetadata host_to_device = 2;
+}
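
For context only (not part of this change): a minimal sketch of how the C++ classes generated from this proto could be populated, using the standard protobuf-generated accessors. The key name and the single scalar float tensor are invented for the example.

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/core/framework/types.pb.h"

// Describes one device-to-host transfer carrying a single float scalar.
tensorflow::tf2xla::HostComputeMetadata MakeExampleMetadata() {
  tensorflow::tf2xla::HostComputeMetadata metadata;
  tensorflow::tf2xla::HostTransferMetadata* send = metadata.add_device_to_host();
  send->set_key("host_compute_channel_0");  // Hypothetical key.
  tensorflow::tf2xla::TensorMetadata* tensor = send->add_metadata();
  tensor->set_type(tensorflow::DT_FLOAT);
  tensor->mutable_shape();  // Empty TensorShapeProto, i.e. a scalar.
  return metadata;
}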

tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -674,6 +674,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
+
+  // Copy the host transfer metadata to the result.
+  for (const auto& send : host_compute_sends_) {
+    *result->host_compute_metadata.add_device_to_host() = send.second;
+  }
+  for (const auto& recv : host_compute_recvs_) {
+    *result->host_compute_metadata.add_host_to_device() = recv.second;
+  }
 
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
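
Illustrative only: one way a caller could inspect the metadata now recorded in the compilation result. The fragment assumes `result` is the CompilationResult filled in by the CompileGraph call above.

// Log every host compute transfer gathered during compilation.
for (const auto& send : result.host_compute_metadata.device_to_host()) {
  VLOG(1) << "device_to_host transfer '" << send.key() << "' carries "
          << send.metadata_size() << " tensor(s)";
}
for (const auto& recv : result.host_compute_metadata.host_to_device()) {
  VLOG(1) << "host_to_device transfer '" << recv.key() << "' carries "
          << recv.metadata_size() << " tensor(s)";
}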
@@ -708,4 +716,59 @@ Status XlaCompiler::GetChannelHandle(const string& key,
   return Status::OK();
 }
 
+namespace {
+
+void SetTransfer(const string& key, const std::vector<DataType>& types,
+                 const std::vector<TensorShape>& shapes,
+                 tf2xla::HostTransferMetadata* transfer) {
+  transfer->set_key(key);
+  CHECK(types.size() == shapes.size());
+  for (int i = 0; i < types.size(); ++i) {
+    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
+    metadata->set_type(types[i]);
+    shapes[i].AsProto(metadata->mutable_shape());
+  }
+}
+
+}  // namespace
+
+Status XlaCompiler::SetDeviceToHostMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostShapes(
+    const string& key, std::vector<TensorShape>* shapes) const {
+  const auto iter = host_compute_sends_.find(key);
+  if (iter == host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "No host compute send shapes registered for key ", key);
+  }
+  shapes->clear();
+  for (int i = 0; i < iter->second.metadata_size(); ++i) {
+    TensorShape shape(iter->second.metadata(i).shape());
+    shapes->push_back(shape);
+  }
+  return Status::OK();
+}
+
+Status XlaCompiler::SetHostToDeviceMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
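
A sketch of how the new registration methods might be used before compilation; `compiler` is an assumed, already-constructed XlaCompiler, and the key "host_compute_0" is invented for the example.

std::vector<DataType> types = {DT_FLOAT, DT_INT32};
std::vector<TensorShape> shapes = {TensorShape({2, 3}), TensorShape({})};

// Register the device-to-host transfer once; a second call with the same
// key returns an InvalidArgument error.
TF_CHECK_OK(compiler.SetDeviceToHostMetadata("host_compute_0", types, shapes));

// Later, the registered shapes can be queried back by key.
std::vector<TensorShape> registered_shapes;
TF_CHECK_OK(compiler.GetDeviceToHostShapes("host_compute_0", &registered_shapes));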

tensorflow/compiler/tf2xla/xla_compiler.h

@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -216,6 +217,10 @@ class XlaCompiler {
     // containing both constant and non-constant results.
     std::vector<OutputDescription> outputs;
 
+    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+    // matching RecvAtHost/SendFromHost Ops in the outer graph.
+    tf2xla::HostComputeMetadata host_compute_metadata;
+
     // Resources whose values were updated by the computation, ordered
     // by return value position. Resource updates follow the non-constant
     // results in the outputs of XLA computation.
@@ -296,6 +301,22 @@ class XlaCompiler {
   // same XlaCompiler.
   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
 
+  // Sets the shapes and types for the device to host transfer associated with
+  // 'key'.
+  Status SetDeviceToHostMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
+  // Gets the shapes for the device to host transfer associated with 'key'.
+  Status GetDeviceToHostShapes(const string& key,
+                               std::vector<TensorShape>* shapes) const;
+
+  // Sets the shapes and types for the host to device transfer associated with
+  // 'key'.
+  Status SetHostToDeviceMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
   const Options& options() const { return options_; }
   xla::Client* client() const { return options_.client; }
   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
@@ -359,6 +380,9 @@ class XlaCompiler {
 
   std::unordered_map<string, xla::ChannelHandle> channels_;
 
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
 