Make Send and Recv asynchronous by splitting Send into {Send, SendDone}

and Recv into {Recv, RecvDone}. See operation_semantics.md for the updated
semantics.

PiperOrigin-RevId: 175216012
HyoukJoong Lee 2017-11-09 14:48:37 -08:00 committed by TensorFlower Gardener
parent a0e9c52921
commit f3f85e9aa0
31 changed files with 550 additions and 47 deletions

View File

@@ -819,17 +819,6 @@ Status BufferAssigner::AssignBuffersForComputation(
continue;
}
if (instruction->opcode() == HloOpcode::kRecv) {
// Make sure that recv operations get a new unique allocation so that they
// don't share their buffer with any other operations.
BufferAllocation* allocation = assignment->NewAllocation(
*buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
allocation_indices.push_back(allocation->index());
VLOG(3) << "New allocation #" << allocation->index()
<< " for recv: " << *buffer;
continue;
}
if (ShapeUtil::IsTuple(buffer->shape())) {
// TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
// assumes longer buffer liveness than indicated by the analysis.

View File

@@ -1983,6 +1983,11 @@ Status IrEmitter::HandleSend(HloInstruction* send) {
return Unimplemented("Send is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
@@ -2148,6 +2153,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) {
return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandlePad(HloInstruction* pad) {
// CPU backend does not properly handle negative padding but this is ok
// because negative padding should be removed by the algebraic simplifier.

View File

@@ -171,11 +171,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleReduceWindow(HloInstruction* reduce_window) override;
Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
Status HandleSend(HloInstruction* send) override;
Status HandleSendDone(HloInstruction* send_done) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override;
Status HandleRecv(HloInstruction* recv) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleMap(HloInstruction* map) override;

View File

@@ -211,9 +211,11 @@ class DfsHloVisitorBase {
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
virtual Status HandleSend(HloInstructionPtr hlo) = 0;
virtual Status HandleSend(HloInstructionPtr send) = 0;
virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
virtual Status HandleRecv(HloInstructionPtr hlo) = 0;
virtual Status HandleRecv(HloInstructionPtr recv) = 0;
virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;

View File

@@ -167,11 +167,17 @@ class DfsHloVisitorWithDefaultBase
Status HandleWhile(HloInstructionPtr xla_while) override {
return DefaultAction(xla_while);
}
Status HandleRecv(HloInstructionPtr recv) override {
return DefaultAction(recv);
}
Status HandleRecvDone(HloInstructionPtr recv_done) override {
return DefaultAction(recv_done);
}
Status HandleSend(HloInstructionPtr send) override {
return DefaultAction(send);
}
Status HandleRecv(HloInstructionPtr recv) override {
return DefaultAction(recv);
Status HandleSendDone(HloInstructionPtr send_done) override {
return DefaultAction(send_done);
}
// Invoked to inform the visitor that the traversal has completed, and that

View File

@@ -128,10 +128,18 @@ Status IrEmitter::HandleSend(HloInstruction*) {
return Unimplemented("Send is not implemented on GPU");
}
Status IrEmitter::HandleSendDone(HloInstruction*) {
return Unimplemented("Send-Done is not implemented on GPU");
}
Status IrEmitter::HandleRecv(HloInstruction*) {
return Unimplemented("Recv is not implemented on GPU");
}
Status IrEmitter::HandleRecvDone(HloInstruction*) {
return Unimplemented("Recv-done is not implemented on GPU");
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) {

View File

@@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleSort(HloInstruction* sort) override;
Status HandleSend(HloInstruction* send) override;
Status HandleSendDone(HloInstruction* send_done) override;
Status HandleRecv(HloInstruction* recv) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;

View File

@@ -337,10 +337,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
return Status::OK();
}

View File

@@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleReducePrecision(const HloInstruction* hlo) override;
Status HandleConcatenate(const HloInstruction* concatenate) override;
Status HandleSend(const HloInstruction* send) override;
Status HandleSendDone(const HloInstruction* send_done) override;
Status HandleRecv(const HloInstruction* recv) override;
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
Status HandleDot(const HloInstruction* dot) override;

View File

@@ -242,6 +242,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
return false;
}
bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
CHECK_EQ(send->opcode(), HloOpcode::kSend);
bool changed = false;
// Send forwards the operand value to the output tuple at {0}.
for (auto& pair : GetInstructionValueSet(send->operand(0))) {
const ShapeIndex& operand_index = pair.first;
const HloValueSet& operand_value_set = pair.second;
ShapeIndex index = {0};
for (int64 i : operand_index) {
index.push_back(i);
}
HloValueSet& value_set = GetValueSet(send, index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
bool changed = false;
// RecvDone forwards the operand value at {0} to the output.
for (auto& pair : GetInstructionValueSet(recv_done)) {
ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
ShapeIndex operand_index = {0};
for (int64 i : index) {
operand_index.push_back(i);
}
const HloValueSet& operand_value_set =
GetValueSet(recv_done->operand(0), operand_index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
CHECK_EQ(call->opcode(), HloOpcode::kCall);
InstructionValueSet& value_set = GetInstructionValueSet(call);
@@ -429,6 +474,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateCallValueSet(instruction);
case HloOpcode::kWhile:
return UpdateWhileValueSet(instruction);
case HloOpcode::kSend:
return UpdateSendValueSet(instruction);
case HloOpcode::kRecvDone:
return UpdateRecvDoneValueSet(instruction);
default:
// Instruction does not forward HloValues (it defines all values in its
// output). No update is necessary.
@@ -537,6 +586,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
GetValueSet(instruction, /*index=*/{}).AddValue(value);
};
// Lambda to set the value set at the given index of the output.
auto define_value_at = [this, &instruction](const ShapeIndex& index) {
HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
GetValueSet(instruction, index).AddValue(value);
};
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
if (bitcast_defines_value_) {
@@ -577,6 +632,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values flow from their operands.
define_top_level_only();
break;
case HloOpcode::kRecvDone:
// RecvDone aliases its input tuple element {0}, therefore does not
// define any values.
break;
case HloOpcode::kSend:
// Send produces a tuple of {aliased operand, U32 context}, therefore
// only defines the top-level tuple and the tuple element at {1}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
break;
default:
define_all_values();
break;

View File

@@ -146,7 +146,9 @@ class HloDataflowAnalysis {
bool UpdateCopyValueSet(HloInstruction* copy);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
bool UpdateSelectValueSet(HloInstruction* select);
bool UpdateSendValueSet(HloInstruction* send);
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);

View File

@@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
// Test that a Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(param, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_EQ(analysis.values().size(), 4);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
// Test that a RecvDone forwards its operand tuple element at {0} to the
// output.
auto builder = HloComputation::Builder(TestName());
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_EQ(analysis.values().size(), 3);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
EXPECT_THAT(HloValuesAt(recv_done),
UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
EXPECT_TRUE(
analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
// A simple chain of elementwise operations. No values should interfere.
//

View File

@@ -943,7 +943,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kFusion:
return kGray;
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
@@ -1037,7 +1039,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
? ""
: StrCat("stride=", VectorString(instr->slice_strides()));
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return StrCat("channel_id=", instr->channel_id());
default:
return "";

View File

@@ -371,20 +371,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, int64 channel_id) {
// Send instruction produces a tuple of {aliased operand, U32 context}.
Shape output_shape = ShapeUtil::MakeTupleShape(
{operand->shape(), ShapeUtil::MakeShape(U32, {})});
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
instruction->AppendOperand(operand);
instruction->channel_id_ = channel_id;
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand) {
CHECK(operand->opcode() == HloOpcode::kSend)
<< "SendDone must take the context operand from Send";
auto instruction = WrapUnique(
new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
instruction->AppendOperand(operand);
instruction->channel_id_ = operand->channel_id();
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, int64 channel_id) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
// Recv instruction produces a tuple of {receive buffer, U32 context}.
Shape output_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
instruction->channel_id_ = channel_id;
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand) {
CHECK(operand->opcode() == HloOpcode::kRecv)
<< "RecvDone must take the context operand from Recv";
Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
instruction->AppendOperand(operand);
instruction->channel_id_ = operand->channel_id();
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
@@ -908,7 +938,9 @@ RandomDistribution HloInstruction::random_distribution() const {
bool HloInstruction::HasSideEffect() const {
switch (opcode_) {
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
@@ -1164,7 +1196,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[4], epsilon(), feature_index());
break;
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1557,8 +1591,10 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return false;
}
}
@@ -1891,7 +1927,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
})));
}
if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
extra.push_back(StrCat("channel_id=", channel_id_));
}
@@ -2071,8 +2108,10 @@ bool HloInstruction::IsFusable() const {
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kTrace:
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0)
@@ -2279,10 +2318,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCall(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
case HloOpcode::kSend:
return visitor->HandleSend(this);
case HloOpcode::kRecv:
return visitor->HandleRecv(this);
case HloOpcode::kRecvDone:
return visitor->HandleRecvDone(this);
case HloOpcode::kSend:
return visitor->HandleSend(this);
case HloOpcode::kSendDone:
return visitor->HandleSendDone(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:

View File

@@ -181,18 +181,28 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config);
// Creates a send instruction with the given channel id, which sends the
// operand data to a unique receive instruction in another computation that
// has the same channel id.
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
// another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
int64 channel_id);
// Creates a receive instruction with the given channel id, which receives
// data of the given shape from a unique send instruction in another
// computation that has the same channel id.
// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand must be kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand);
// Creates an asynchronous receive instruction with the given channel id,
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
int64 channel_id);
// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand);
// Creates a slice instruction, where the operand is sliced by the given
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(

View File

@@ -121,6 +121,7 @@ HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
HLO_MATCHER(Power);
HLO_MATCHER(Recv);
HLO_MATCHER(RecvDone);
HLO_MATCHER(Reduce);
HLO_MATCHER(ReducePrecision);
HLO_MATCHER(ReduceWindow);
@@ -131,6 +132,7 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);
HLO_MATCHER(SendDone);
HLO_MATCHER(ShiftLeft);
HLO_MATCHER(ShiftRightLogical);
HLO_MATCHER(ShiftRightArithmetic);

View File

@@ -97,6 +97,7 @@ namespace xla {
V(kPower, "power") \
V(kReal, "real") \
V(kRecv, "recv") \
V(kRecvDone, "recv-done") \
V(kReduce, "reduce") \
V(kReducePrecision, "reduce-precision") \
V(kReduceWindow, "reduce-window") \
@@ -108,6 +109,7 @@ namespace xla {
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \
V(kSendDone, "send-done") \
V(kShiftLeft, "shift-left") \
V(kShiftRightArithmetic, "shift-right-arithmetic") \
V(kShiftRightLogical, "shift-right-logical") \

View File

@@ -66,7 +66,9 @@ bool IsRematerializable(const HloInstruction* instruction) {
case HloOpcode::kInfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
return false;

View File

@@ -270,12 +270,40 @@ class ShapeVerifier : public DfsHloVisitor {
pad->padding_config()));
}
Status HandleSend(HloInstruction*) override {
return tensorflow::Status::OK();
Status HandleSend(HloInstruction* send) override {
TF_RET_CHECK(send->users().size() == 1);
const HloInstruction* send_done = send->users()[0];
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
return CheckShape(
send, ShapeUtil::MakeTupleShape(
{send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
}
Status HandleRecv(HloInstruction*) override {
return tensorflow::Status::OK();
Status HandleSendDone(HloInstruction* send_done) override {
TF_RET_CHECK(send_done->operands().size() == 1);
const HloInstruction* send = send_done->operand(0);
TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
return CheckShape(send_done, ShapeUtil::MakeNil());
}
Status HandleRecv(HloInstruction* recv) override {
TF_RET_CHECK(recv->users().size() == 1);
const HloInstruction* recv_done = recv->users()[0];
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
return CheckShape(recv,
ShapeUtil::MakeTupleShape(
{recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
}
Status HandleRecvDone(HloInstruction* recv_done) override {
TF_RET_CHECK(recv_done->operands().size() == 1);
const HloInstruction* recv = recv_done->operand(0);
TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
return CheckShape(recv_done, recv->shape().tuple_shapes(0));
}
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
@@ -365,6 +393,19 @@ class ShapeVerifier : public DfsHloVisitor {
instruction->opcode(), instruction->operands()));
}
// Checks whether the given two instructions share the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
if (instr1->channel_id() != instr2->channel_id()) {
return FailedPrecondition(
"Expected to have the same channel id, actual channel ids are: %s "
"(%lld), %s (%lld)",
instr1->ToString().c_str(), instr1->channel_id(),
instr2->ToString().c_str(), instr2->channel_id());
}
return tensorflow::Status::OK();
}
// Returns the size of a Shape in bytes.
const std::function<int64(const Shape&)> shape_size_fn_;
};

View File

@@ -113,7 +113,9 @@ namespace xla {
case HloOpcode::kTrace:
case HloOpcode::kWhile:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return true;
}

View File

@@ -104,6 +104,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
return Status::OK();
}
Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
// RecvDone doesn't create a new buffer but rather aliases its input (Recv)
// tuple element at {0} to its output.
return Status::OK();
}
Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
// Send creates new buffers for the top-level tuple and the context (tuple
// element at {1}). Tuple element at {0} is an alias of the Send operand, so
// we don't need to create a new Logical Buffer for that.
NewLogicalBuffer(send, /*index=*/{});
NewLogicalBuffer(send, /*index=*/{1});
return Status::OK();
}
Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
// A Tuple instruction only creates the top-level buffer.
NewLogicalBuffer(tuple, /*index=*/{});

View File

@@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleSelect(HloInstruction* select) override;
// A map from the buffer ID to the logical buffer

View File

@@ -253,6 +253,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// RecvDone aliases its input (Recv) tuple element {0} to its output.
PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
const PointsToSet& operand_points_to_set =
GetPointsToSet(recv_done->operand(0));
// Recursively copy the points to set of the operand tuple {0}.
points_to_set.ForEachMutableElement(
[this, &points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
ShapeIndex src_index({0});
for (auto element : index) {
src_index.push_back(element);
}
*buffers = operand_points_to_set.element(src_index);
for (auto& tuple_source :
operand_points_to_set.tuple_sources(src_index)) {
points_to_set.add_tuple_source(index, tuple_source);
}
});
return Status::OK();
}
Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
// Send creates a tuple of {aliased operand, U32 context}.
PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
// Creates the points to set for the tuple and its element at {1}.
auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
top_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
points_to_set.add_tuple_source({}, send);
auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
context_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
// Recursively copy the points to set of the operand to output tuple {0}.
const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
operand_points_to_set.ForEachElement(
[&points_to_set, &operand_points_to_set](
const ShapeIndex& src_index,
const PointsToSet::BufferList& points_to) {
ShapeIndex target_index({0});
for (auto element : src_index) {
target_index.push_back(element);
}
*points_to_set.mutable_element(target_index) = points_to;
for (HloInstruction* tuple :
operand_points_to_set.tuple_sources(src_index)) {
points_to_set.add_tuple_source(target_index, tuple);
}
});
return Status::OK();
}
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);

View File

@@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleSelect(HloInstruction* select) override;
string ToString() const;

View File

@@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
{constant1, constant2, copy});
}
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
// Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(constant, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send).element({}), {send});
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
{send_done});
ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
}
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
}
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
// Select from two different tuples. This should create an ambiguous points to
// set containing the union of both sides.

View File

@@ -2927,8 +2927,9 @@ void ComputationLowerer::Visit(
case OpRequest::kRecvRequest: {
const RecvRequest& recv_request = request.request().recv_request();
hlo_instruction = add_instruction(HloInstruction::CreateRecv(
HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
request.output_shape(), recv_request.channel_handle().handle()));
hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
break;
}
@@ -3120,8 +3121,9 @@ void ComputationLowerer::Visit(
case OpRequest::kSendRequest: {
const SendRequest& send_request = request.request().send_request();
HloInstruction* operand = lookup_instruction(send_request.operand());
hlo_instruction = add_instruction(HloInstruction::CreateSend(
HloInstruction* send = add_instruction(HloInstruction::CreateSend(
operand, send_request.channel_handle().handle()));
hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
break;
}

View File

@@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) {
static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
if (instr->opcode() == HloOpcode::kSend ||
instr->opcode() == HloOpcode::kRecv) {
instr->opcode() == HloOpcode::kSendDone ||
instr->opcode() == HloOpcode::kRecv ||
instr->opcode() == HloOpcode::kRecvDone) {
return true;
}
for (const auto& subcomp : instr->called_computations()) {

View File

@@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
while_body->AddInstruction(HloInstruction::CreateSend(
auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
@@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
while_body->AddInstruction(
auto* recv = while_body->AddInstruction(
HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}

View File

@@ -442,7 +442,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRecv(shape, *channel_id));
HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
break;
}
case HloOpcode::kRecvDone: {
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
break;
}
case HloOpcode::kSend: {
@@ -456,6 +470,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateSend(operands[0], *channel_id));
break;
}
case HloOpcode::kSendDone: {
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
break;
}
case HloOpcode::kGetTupleElement: {
optional<int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};

View File

@@ -226,9 +226,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
%send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
@@ -522,9 +524,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
const string original = R"(HloModule unexpected_attr_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
%recv = (f32[], u32[]) recv(), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = () send(f32[] %constant), channel_id=16, calls=%recv
%send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
)";
@@ -536,9 +540,11 @@ TEST_F(HloParserTest, MissingAttribute) {
const string original = R"(HloModule missing_attr_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
%recv = (f32[], u32[]) recv(), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
%send = () send(f32[] %constant)
%send = (f32[], u32[]) send(f32[] %constant)
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
)";
@@ -550,9 +556,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
const string original = R"(HloModule pre_not_found_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
%recv = (f32[], u32[]) recv(), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
%send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
)";

View File

@@ -901,6 +901,95 @@ are all 0. Figure below shows examples of different `edge_padding` and
<img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
</div>
## Recv
See also
[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
<b> `Recv(shape, channel_handle)` </b>
| Arguments | Type | Semantics |
| ---------------- | --------------- | ------------------------------------ |
| `shape` | `Shape` | shape of the data to receive |
| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
Receives data of the given shape from a `Send` instruction in another
computation that shares the same channel handle. Returns a
`ComputationDataHandle` for the received data.
The client API of the `Recv` operation represents synchronous communication.
However, the instruction is internally decomposed into two HLO instructions
(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
<b>`Recv(const Shape& shape, int64 channel_id)`</b>
Allocates resources required to receive data from a `Send` instruction with the
same channel_id. Returns a context for the allocated resources, which is used
by a following `RecvDone` instruction to wait for the completion of the data
transfer. The context is a tuple of {receive buffer (shape), request identifier
(U32)} and it can only be used by a `RecvDone` instruction.
<b> `RecvDone(HloInstruction context)` </b>
Given a context created by a `Recv` instruction, waits for the data transfer to
complete and returns the received data.
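The `Recv`/`RecvDone` pair maps directly onto the new `HloInstruction` factory
methods. Below is a minimal sketch following the pattern of this commit's
dataflow-analysis tests (assuming the usual XLA service headers); the builder
name, the scalar `f32[]` shape, and the channel id are illustrative:
```cpp
// Sketch: constructing the asynchronous Recv/RecvDone pair at the HLO level.
// Recv yields a tuple of {receive buffer, u32[] request context}; RecvDone
// waits on that context and returns the received f32[] value.
HloComputation::Builder builder("recv_example");  // hypothetical name
HloInstruction* recv = builder.AddInstruction(HloInstruction::CreateRecv(
    ShapeUtil::MakeShape(F32, {}), /*channel_id=*/15));
HloInstruction* recv_done =
    builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
```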
## Send
See also
[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
<b> `Send(operand, channel_handle)` </b>
| Arguments | Type | Semantics |
| ---------------- | ----------------------- | -------------------------------- |
| `operand` | `ComputationDataHandle` | data to send (array of type T) |
| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
Sends the given operand data to a `Recv` instruction in another computation
that shares the same channel handle. Does not return any data.
Similar to the `Recv` operation, the client API of the `Send` operation
represents synchronous communication and is internally decomposed into two
HLO instructions (`Send` and `SendDone`) to enable asynchronous data
transfers. See also
[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
<b>`Send(HloInstruction operand, int64 channel_id)`</b>
Initiates an asynchronous transfer of the operand to the resources allocated by
the `Recv` instruction with the same channel id. Returns a context, which is
used by a following `SendDone` instruction to wait for the completion of the
data transfer. The context is a tuple of {operand (shape), request identifier
(U32)} and it can only be used by a `SendDone` instruction.
<b> `SendDone(HloInstruction context)` </b>
Given a context created by a `Send` instruction, waits for the data transfer to
complete. The instruction does not return any data.
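A matching sketch for the send side, again modeled on this commit's tests
(the constant operand and channel id are illustrative assumptions):
```cpp
// Sketch: constructing the asynchronous Send/SendDone pair at the HLO level.
// Send returns a tuple of {aliased operand, u32[] request context}; SendDone
// consumes the context and blocks until the transfer completes.
HloComputation::Builder builder("send_example");  // hypothetical name
HloInstruction* operand = builder.AddInstruction(
    HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1)));
HloInstruction* send = builder.AddInstruction(
    HloInstruction::CreateSend(operand, /*channel_id=*/16));
HloInstruction* send_done =
    builder.AddInstruction(HloInstruction::CreateSendDone(send));
```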
<b> Scheduling of channel instructions </b>
The execution order of the four instructions for each channel (`Recv`,
`RecvDone`, `Send`, `SendDone`) is shown below.
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:70%" src="../../images/send_recv_order.png">
</div>
* `Recv` happens before `Send`
* `Send` happens before `RecvDone`
* `Recv` happens before `RecvDone`
* `Send` happens before `SendDone`
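One linear order that satisfies all four constraints is `Recv`, `Send`,
`RecvDone`, `SendDone`.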
When the backend compilers generate a linear schedule for each computation that
communicates via channel instructions, there must not be cycles across the
computations. For example, the schedules below lead to deadlocks.
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/send_recv_schedule.png">
</div>
## Reduce
See also