In HostCompute op, use SendToHost/RecvFromHost instead of Send/Recv.

PiperOrigin-RevId: 209617148
Tong Shen 2018-08-21 10:25:52 -07:00 committed by TensorFlower Gardener
parent 792a933b11
commit 938a3b7779
6 changed files with 52 additions and 12 deletions
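
For context, the host-transfer builder ops this change switches to are distinct from the device-to-device Send/Recv ops: SendToHost/RecvFromHost produce kSend/kSendDone and kRecv/kRecvDone HLOs with is_host_transfer set. The sketch below is illustrative only and is not code from this commit; it assumes the free-function XLA client builder API (xla::SendToHost, xla::RecvFromHost, xla::CreateToken), and the shape, operand, and channel handles are placeholders. Include paths may differ slightly at this revision.

// Illustrative sketch, not part of this commit.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildHostTransferSketch(xla::XlaBuilder* builder, xla::XlaOp operand,
                             const xla::ChannelHandle& device_to_host,
                             const xla::ChannelHandle& host_to_device) {
  // Placeholder shape for the value shuttled to and from the host.
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  xla::XlaOp token = xla::CreateToken(builder);
  // Device -> host: lowers to kSend/kSendDone with is_host_transfer=true,
  // instead of a plain device-to-device Send.
  xla::XlaOp after_send = xla::SendToHost(operand, token, shape, device_to_host);
  // Host -> device: lowers to kRecv/kRecvDone with is_host_transfer=true.
  xla::XlaOp from_host = xla::RecvFromHost(after_send, shape, host_to_device);
  (void)from_host;
}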

@@ -792,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
-  // Copy the host transfer metadata to the result.
-  for (const auto& send : host_compute_sends_) {
-    *result->host_compute_metadata.add_device_to_host() = send.second;
-  }
-  for (const auto& recv : host_compute_recvs_) {
-    *result->host_compute_metadata.add_host_to_device() = recv.second;
-  }
-
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
 
@@ -817,6 +809,30 @@ Status XlaCompiler::GetChannelHandle(const string& key,
   return Status::OK();
 }
 
+Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
+                                                 xla::ChannelHandle* channel) {
+  auto result = channels_.emplace(key, xla::ChannelHandle());
+  if (result.second) {
+    TF_ASSIGN_OR_RETURN(result.first->second,
+                        client()->CreateHostToDeviceChannelHandle());
+  }
+  *channel = result.first->second;
+  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
+  return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
+                                                 xla::ChannelHandle* channel) {
+  auto result = channels_.emplace(key, xla::ChannelHandle());
+  if (result.second) {
+    TF_ASSIGN_OR_RETURN(result.first->second,
+                        client()->CreateDeviceToHostChannelHandle());
+  }
+  *channel = result.first->second;
+  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
+  return Status::OK();
+}
+
 namespace {
 
 void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
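
A possible caller-side sketch of the two accessors added above. It is hypothetical: the helper name, the "_dtoh"/"_htod" key suffixes, and the surrounding plumbing are made up for illustration, and real callers derive the key from the HostCompute op's attributes. Repeated calls with the same key return the handle cached in channels_, so both sides of a transfer agree on the channel.

// Hypothetical usage sketch, assumed to live inside namespace tensorflow.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/lib/core/errors.h"

Status GetHostComputeChannels(XlaCompiler* compiler, const string& key,
                              xla::ChannelHandle* device_to_host,
                              xla::ChannelHandle* host_to_device) {
  // The per-direction key suffixes are placeholders; any stable naming works
  // as long as the device and host sides use the same keys.
  TF_RETURN_IF_ERROR(
      compiler->GetDeviceToHostChannelHandle(key + "_dtoh", device_to_host));
  TF_RETURN_IF_ERROR(
      compiler->GetHostToDeviceChannelHandle(key + "_htod", host_to_device));
  return Status::OK();
}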

@@ -332,6 +332,16 @@ class XlaCompiler {
   // same XlaCompiler.
   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
 
+  // Retrieves the host-to-device channel handle associated with `key`.
+  // Allocates a new channel handle if none exists.
+  Status GetHostToDeviceChannelHandle(const string& key,
+                                      xla::ChannelHandle* channel);
+
+  // Retrieves the device-to-host channel handle associated with `key`.
+  // Allocates a new channel handle if none exists.
+  Status GetDeviceToHostChannelHandle(const string& key,
+                                      xla::ChannelHandle* channel);
+
   // Sets the shapes and types for the device to host transfer associated with
   // 'key'.
   Status SetDeviceToHostMetadata(const string& key,

@@ -1098,6 +1098,7 @@ cc_library(
     hdrs = ["hlo_module_group_util.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
         ":hlo_module_group_metadata",
         ":hlo_reachability",
         "//tensorflow/compiler/xla:status",

@@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
   return channels_[channel_id_map_.at(channel_id)];
 }
 
+bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
+  return channel_id_map_.find(channel_id) != channel_id_map_.end();
+}
+
 HloComputation* HloModuleGroupMetadata::PeerComputation(
     const HloInstruction* instruction) const {
   CHECK(IsChannelInstruction(instruction));
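
A short hypothetical caller sketch (not from this commit) of the new query: HasChannel() lets a pass probe for a channel id before calling GetChannel(), which indexes channel_id_map_ with .at() and therefore assumes the id is registered.

// Hypothetical caller sketch. `metadata` is an HloModuleGroupMetadata and
// `instruction` a channel instruction; both names are placeholders.
if (metadata.HasChannel(instruction->channel_id())) {
  const HloModuleGroupMetadata::Channel& channel =
      metadata.GetChannel(instruction->channel_id());
  // ... e.g. follow channel.send / channel.recv_done to the peer module ...
}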

@@ -125,6 +125,9 @@ class HloModuleGroupMetadata {
   // Returns the Channel instance for the given channel id.
   const Channel& GetChannel(int64 channel_id) const;
 
+  // Returns if the given channel id exists in metadata.
+  bool HasChannel(int64 channel_id) const;
+
   // Returns the all-reduce instructions with the same all_reduce_id.
   const std::vector<HloInstruction*>& GetAllReduceGroup(
       int64 all_reduce_id) const;

@@ -23,6 +23,8 @@ limitations under the License.
 #include <utility>
 
 #include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
       add_unique_predecessor(control_predecessor);
     }
   }
-  if (instruction->opcode() == HloOpcode::kRecvDone) {
+  if (instruction->opcode() == HloOpcode::kRecvDone &&
+      !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
     // Send is a remote predecessor of RecvDone.
     HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
     add_unique_predecessor(send);
   }
-  if (instruction->opcode() == HloOpcode::kSend) {
+  if (instruction->opcode() == HloOpcode::kSend &&
+      !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
     // Recv is a remote predecessor of Send.
     HloInstruction* recv_done =
         metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
       add_unique_successor(control_successor);
     }
   }
-  if (instruction->opcode() == HloOpcode::kRecv) {
+  if (instruction->opcode() == HloOpcode::kRecv &&
+      !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
     // Send is a remote successor of Recv.
     const HloInstruction* recv_done = instruction->users().front();
     CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
     HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
     add_unique_successor(send);
   }
-  if (instruction->opcode() == HloOpcode::kSend) {
+  if (instruction->opcode() == HloOpcode::kSend &&
+      !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
     // RecvDone is a remote successor of Send.
     HloInstruction* recv_done =
         metadata_.GetChannel(instruction->channel_id()).recv_done;
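
The same opcode-plus-is_host_transfer test now appears at four call sites (kRecvDone and kSend in GlobalPredecessors, kRecv and kSend in GlobalSuccessors). A possible follow-up, not part of this change, would be a small helper along these lines; the function name is hypothetical, and it mirrors the per-opcode DynCast pattern used above.

// Hypothetical helper, not in this commit: true only for Send/Recv-family
// instructions whose peer lives on another device (i.e. not host transfers).
bool IsDeviceToDeviceChannelInstruction(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kSend:
      return !DynCast<HloSendInstruction>(instruction)->is_host_transfer();
    case HloOpcode::kSendDone:
      return !DynCast<HloSendDoneInstruction>(instruction)->is_host_transfer();
    case HloOpcode::kRecv:
      return !DynCast<HloRecvInstruction>(instruction)->is_host_transfer();
    case HloOpcode::kRecvDone:
      return !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer();
    default:
      return false;
  }
}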