In HostCompute op, use SendToHost/RecvFromHost instead of Send/Recv.
PiperOrigin-RevId: 209617148
This commit is contained in:
parent
792a933b11
commit
938a3b7779
@ -792,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
VLOG(2) << "XLA output shape: "
|
||||
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
||||
|
||||
// Copy the host transfer metadata to the result.
|
||||
for (const auto& send : host_compute_sends_) {
|
||||
*result->host_compute_metadata.add_device_to_host() = send.second;
|
||||
}
|
||||
for (const auto& recv : host_compute_recvs_) {
|
||||
*result->host_compute_metadata.add_host_to_device() = recv.second;
|
||||
}
|
||||
|
||||
// Tensorflow expects a major-to-minor order of results.
|
||||
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
|
||||
|
||||
@ -817,6 +809,30 @@ Status XlaCompiler::GetChannelHandle(const string& key,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Looks up the host-to-device channel handle cached under `key`, creating one
// through the client on first use.
//
// On success, copies the handle into `*channel` and returns OK. If creating a
// new handle fails, the error is returned, `*channel` is left untouched and
// nothing is cached for `key`, so a later call attempts the creation again
// instead of handing back a default-constructed handle.
//
// NOTE(review): the cache is keyed by `key` alone; if the same key were ever
// used for a channel of a different kind, the cached handle would be returned
// as-is. Confirm that callers use distinct keys per channel kind.
Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto it = channels_.find(key);
  if (it == channels_.end()) {
    // Create the handle before touching the cache: TF_ASSIGN_OR_RETURN returns
    // early on error, and an entry inserted beforehand would be left behind
    // holding a default-constructed handle.
    TF_ASSIGN_OR_RETURN(xla::ChannelHandle handle,
                        client()->CreateHostToDeviceChannelHandle());
    it = channels_.emplace(key, handle).first;
  }
  *channel = it->second;
  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
  return Status::OK();
}
|
||||
|
||||
// Looks up the device-to-host channel handle cached under `key`, creating one
// through the client on first use.
//
// On success, copies the handle into `*channel` and returns OK. If creating a
// new handle fails, the error is returned, `*channel` is left untouched and
// nothing is cached for `key`, so a later call attempts the creation again
// instead of handing back a default-constructed handle.
//
// NOTE(review): the cache is keyed by `key` alone; if the same key were ever
// used for a channel of a different kind, the cached handle would be returned
// as-is. Confirm that callers use distinct keys per channel kind.
Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto it = channels_.find(key);
  if (it == channels_.end()) {
    // Create the handle before touching the cache: TF_ASSIGN_OR_RETURN returns
    // early on error, and an entry inserted beforehand would be left behind
    // holding a default-constructed handle.
    TF_ASSIGN_OR_RETURN(xla::ChannelHandle handle,
                        client()->CreateDeviceToHostChannelHandle());
    it = channels_.emplace(key, handle).first;
  }
  *channel = it->second;
  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
  return Status::OK();
}
|
||||
|
||||
namespace {
|
||||
|
||||
void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
|
||||
|
@ -332,6 +332,16 @@ class XlaCompiler {
|
||||
// same XlaCompiler.
|
||||
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
|
||||
|
||||
// Retrieves the host-to-device channel handle associated with `key`.
|
||||
// Allocates a new channel handle if none exists.
|
||||
Status GetHostToDeviceChannelHandle(const string& key,
|
||||
xla::ChannelHandle* channel);
|
||||
|
||||
// Retrieves the device-to-host channel handle associated with `key`.
|
||||
// Allocates a new channel handle if none exists.
|
||||
Status GetDeviceToHostChannelHandle(const string& key,
|
||||
xla::ChannelHandle* channel);
|
||||
|
||||
// Sets the shapes and types for the device to host transfer associated with
|
||||
// 'key'.
|
||||
Status SetDeviceToHostMetadata(const string& key,
|
||||
|
@ -1098,6 +1098,7 @@ cc_library(
|
||||
hdrs = ["hlo_module_group_util.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_module_group_metadata",
|
||||
":hlo_reachability",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
|
@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
|
||||
return channels_[channel_id_map_.at(channel_id)];
|
||||
}
|
||||
|
||||
// Reports whether `channel_id` has an entry in channel_id_map_. Unlike
// GetChannel-style lookups via at(), this never fails for an unknown id.
bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
  const auto entry = channel_id_map_.find(channel_id);
  const bool is_known = (entry != channel_id_map_.end());
  return is_known;
}
|
||||
|
||||
HloComputation* HloModuleGroupMetadata::PeerComputation(
|
||||
const HloInstruction* instruction) const {
|
||||
CHECK(IsChannelInstruction(instruction));
|
||||
|
@ -125,6 +125,9 @@ class HloModuleGroupMetadata {
|
||||
// Returns the Channel instance for the given channel id.
|
||||
const Channel& GetChannel(int64 channel_id) const;
|
||||
|
||||
// Returns true if the given channel id exists in the metadata.
|
||||
bool HasChannel(int64 channel_id) const;
|
||||
|
||||
// Returns the all-reduce instructions with the same all_reduce_id.
|
||||
const std::vector<HloInstruction*>& GetAllReduceGroup(
|
||||
int64 all_reduce_id) const;
|
||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
|
||||
add_unique_predecessor(control_predecessor);
|
||||
}
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kRecvDone) {
|
||||
if (instruction->opcode() == HloOpcode::kRecvDone &&
|
||||
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
|
||||
// Send is a remote predecessor of RecvDone.
|
||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||
add_unique_predecessor(send);
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kSend) {
|
||||
if (instruction->opcode() == HloOpcode::kSend &&
|
||||
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||
// Recv is a remote predecessor of Send.
|
||||
HloInstruction* recv_done =
|
||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||
@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
|
||||
add_unique_successor(control_successor);
|
||||
}
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kRecv) {
|
||||
if (instruction->opcode() == HloOpcode::kRecv &&
|
||||
!DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
|
||||
// Send is a remote successor of Recv.
|
||||
const HloInstruction* recv_done = instruction->users().front();
|
||||
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
|
||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||
add_unique_successor(send);
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kSend) {
|
||||
if (instruction->opcode() == HloOpcode::kSend &&
|
||||
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||
// RecvDone is a remote successor of Send.
|
||||
HloInstruction* recv_done =
|
||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||
|
Loading…
Reference in New Issue
Block a user