In HostCompute op, use SendToHost/RecvFromHost instead of Send/Recv.
PiperOrigin-RevId: 209617148
This commit is contained in:
parent
792a933b11
commit
938a3b7779
@ -792,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
|||||||
VLOG(2) << "XLA output shape: "
|
VLOG(2) << "XLA output shape: "
|
||||||
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
||||||
|
|
||||||
// Copy the host transfer metadata to the result.
|
|
||||||
for (const auto& send : host_compute_sends_) {
|
|
||||||
*result->host_compute_metadata.add_device_to_host() = send.second;
|
|
||||||
}
|
|
||||||
for (const auto& recv : host_compute_recvs_) {
|
|
||||||
*result->host_compute_metadata.add_host_to_device() = recv.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tensorflow expects a major-to-minor order of results.
|
// Tensorflow expects a major-to-minor order of results.
|
||||||
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
|
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
|
||||||
|
|
||||||
@ -817,6 +809,30 @@ Status XlaCompiler::GetChannelHandle(const string& key,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores in `*channel` the host-to-device channel handle cached under `key`,
// creating a new handle through the client when no entry exists yet.
//
// Returns the error produced by CreateHostToDeviceChannelHandle() if creation
// fails. In that case nothing is cached and `*channel` is left unmodified, so
// a later call with the same key attempts creation again.
//
// NOTE(review): handles are cached in `channels_` by `key` alone; if other
// channel getters share this map, callers presumably use distinct keys per
// channel kind -- confirm.
Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto it = channels_.find(key);
  if (it == channels_.end()) {
    // Create the handle before touching the cache: inserting a placeholder
    // first would leave a default-constructed handle behind on failure, and
    // every subsequent lookup of `key` would then succeed with that handle.
    TF_ASSIGN_OR_RETURN(xla::ChannelHandle handle,
                        client()->CreateHostToDeviceChannelHandle());
    it = channels_.emplace(key, handle).first;
  }
  *channel = it->second;
  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Stores in `*channel` the device-to-host channel handle cached under `key`,
// creating a new handle through the client when no entry exists yet.
//
// Returns the error produced by CreateDeviceToHostChannelHandle() if creation
// fails. In that case nothing is cached and `*channel` is left unmodified, so
// a later call with the same key attempts creation again.
//
// NOTE(review): handles are cached in `channels_` by `key` alone; if other
// channel getters share this map, callers presumably use distinct keys per
// channel kind -- confirm.
Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto it = channels_.find(key);
  if (it == channels_.end()) {
    // Create the handle before touching the cache: inserting a placeholder
    // first would leave a default-constructed handle behind on failure, and
    // every subsequent lookup of `key` would then succeed with that handle.
    TF_ASSIGN_OR_RETURN(xla::ChannelHandle handle,
                        client()->CreateDeviceToHostChannelHandle());
    it = channels_.emplace(key, handle).first;
  }
  *channel = it->second;
  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
  return Status::OK();
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
|
void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
|
||||||
|
@ -332,6 +332,16 @@ class XlaCompiler {
|
|||||||
// same XlaCompiler.
|
// same XlaCompiler.
|
||||||
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
|
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
|
||||||
|
|
||||||
|
// Retrieves the host-to-device channel handle associated with `key`.
|
||||||
|
// Allocates a new channel handle if none exists.
|
||||||
|
Status GetHostToDeviceChannelHandle(const string& key,
|
||||||
|
xla::ChannelHandle* channel);
|
||||||
|
|
||||||
|
// Retrieves the device-to-host channel handle associated with `key`.
|
||||||
|
// Allocates a new channel handle if none exists.
|
||||||
|
Status GetDeviceToHostChannelHandle(const string& key,
|
||||||
|
xla::ChannelHandle* channel);
|
||||||
|
|
||||||
// Sets the shapes and types for the device to host transfer associated with
|
// Sets the shapes and types for the device to host transfer associated with
|
||||||
// 'key'.
|
// 'key'.
|
||||||
Status SetDeviceToHostMetadata(const string& key,
|
Status SetDeviceToHostMetadata(const string& key,
|
||||||
|
@ -1098,6 +1098,7 @@ cc_library(
|
|||||||
hdrs = ["hlo_module_group_util.h"],
|
hdrs = ["hlo_module_group_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":hlo_casting_utils",
|
||||||
":hlo_module_group_metadata",
|
":hlo_module_group_metadata",
|
||||||
":hlo_reachability",
|
":hlo_reachability",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
|
@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
|
|||||||
return channels_[channel_id_map_.at(channel_id)];
|
return channels_[channel_id_map_.at(channel_id)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reports whether `channel_id` has an entry in the channel id map.
bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
  const auto entry = channel_id_map_.find(channel_id);
  const bool is_known = entry != channel_id_map_.end();
  return is_known;
}
|
||||||
|
|
||||||
HloComputation* HloModuleGroupMetadata::PeerComputation(
|
HloComputation* HloModuleGroupMetadata::PeerComputation(
|
||||||
const HloInstruction* instruction) const {
|
const HloInstruction* instruction) const {
|
||||||
CHECK(IsChannelInstruction(instruction));
|
CHECK(IsChannelInstruction(instruction));
|
||||||
|
@ -125,6 +125,9 @@ class HloModuleGroupMetadata {
|
|||||||
// Returns the Channel instance for the given channel id.
|
// Returns the Channel instance for the given channel id.
|
||||||
const Channel& GetChannel(int64 channel_id) const;
|
const Channel& GetChannel(int64 channel_id) const;
|
||||||
|
|
||||||
|
// Returns true if the given channel id exists in the metadata.
|
||||||
|
bool HasChannel(int64 channel_id) const;
|
||||||
|
|
||||||
// Returns the all-reduce instructions with the same all_reduce_id.
|
// Returns the all-reduce instructions with the same all_reduce_id.
|
||||||
const std::vector<HloInstruction*>& GetAllReduceGroup(
|
const std::vector<HloInstruction*>& GetAllReduceGroup(
|
||||||
int64 all_reduce_id) const;
|
int64 all_reduce_id) const;
|
||||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
|
|||||||
add_unique_predecessor(control_predecessor);
|
add_unique_predecessor(control_predecessor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (instruction->opcode() == HloOpcode::kRecvDone) {
|
if (instruction->opcode() == HloOpcode::kRecvDone &&
|
||||||
|
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
|
||||||
// Send is a remote predecessor of RecvDone.
|
// Send is a remote predecessor of RecvDone.
|
||||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||||
add_unique_predecessor(send);
|
add_unique_predecessor(send);
|
||||||
}
|
}
|
||||||
if (instruction->opcode() == HloOpcode::kSend) {
|
if (instruction->opcode() == HloOpcode::kSend &&
|
||||||
|
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||||
// Recv is a remote predecessor of Send.
|
// Recv is a remote predecessor of Send.
|
||||||
HloInstruction* recv_done =
|
HloInstruction* recv_done =
|
||||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||||
@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
|
|||||||
add_unique_successor(control_successor);
|
add_unique_successor(control_successor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (instruction->opcode() == HloOpcode::kRecv) {
|
if (instruction->opcode() == HloOpcode::kRecv &&
|
||||||
|
!DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
|
||||||
// Send is a remote successor of Recv.
|
// Send is a remote successor of Recv.
|
||||||
const HloInstruction* recv_done = instruction->users().front();
|
const HloInstruction* recv_done = instruction->users().front();
|
||||||
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
|
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
|
||||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||||
add_unique_successor(send);
|
add_unique_successor(send);
|
||||||
}
|
}
|
||||||
if (instruction->opcode() == HloOpcode::kSend) {
|
if (instruction->opcode() == HloOpcode::kSend &&
|
||||||
|
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||||
// RecvDone is a remote successor of Send.
|
// RecvDone is a remote successor of Send.
|
||||||
HloInstruction* recv_done =
|
HloInstruction* recv_done =
|
||||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||||
|
Loading…
Reference in New Issue
Block a user