diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 43ff5fcef89..ac1deae4b20 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -792,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); - // Copy the host transfer metadata to the result. - for (const auto& send : host_compute_sends_) { - *result->host_compute_metadata.add_device_to_host() = send.second; - } - for (const auto& recv : host_compute_recvs_) { - *result->host_compute_metadata.add_host_to_device() = recv.second; - } - // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); @@ -817,6 +809,30 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateHostToDeviceChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateDeviceToHostChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + namespace { void SetTransfer(const string& key, gtl::ArraySlice types, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 25332c8d8e3..fde47dbdec8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -332,6 +332,16 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Retrieves the host-to-device channel handle associated with `key`. + // Allocates a new channel handle if none exists. + Status GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel); + + // Retrieves the device-to-host channel handle associated with `key`. + // Allocates a new channel handle if none exists. + Status GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with // 'key'. Status SetDeviceToHostMetadata(const string& key, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 01f273ad1f7..7fdffe85c0e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1098,6 +1098,7 @@ cc_library( hdrs = ["hlo_module_group_util.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_group_metadata", ":hlo_reachability", "//tensorflow/compiler/xla:status", diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 3b512bf0f81..cd10913763c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( return channels_[channel_id_map_.at(channel_id)]; } +bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const { + return channel_id_map_.find(channel_id) != channel_id_map_.end(); +} + HloComputation* HloModuleGroupMetadata::PeerComputation( const HloInstruction* instruction) const { CHECK(IsChannelInstruction(instruction)); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 1b256cd00e6..924c8fda71a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -125,6 +125,9 @@ class HloModuleGroupMetadata { // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns if the given channel id exists in metadata. + bool HasChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. const std::vector& GetAllReduceGroup( int64 all_reduce_id) const; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 4f11ce322e6..1a4da388e4a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -94,12 +96,14 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( add_unique_predecessor(control_predecessor); } } - if (instruction->opcode() == HloOpcode::kRecvDone) { + if (instruction->opcode() == HloOpcode::kRecvDone && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote predecessor of RecvDone. HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_predecessor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // Recv is a remote predecessor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -170,14 +174,16 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( add_unique_successor(control_successor); } } - if (instruction->opcode() == HloOpcode::kRecv) { + if (instruction->opcode() == HloOpcode::kRecv && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote successor of Recv. const HloInstruction* recv_done = instruction->users().front(); CHECK(recv_done->opcode() == HloOpcode::kRecvDone); HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_successor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // RecvDone is a remote successor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done;