Change Send and Recv to be asynchronous by splitting Send into {Send, SendDone}
and Recv into {Recv, RecvDone}. See operation_semantics.md for the updated semantics.

PiperOrigin-RevId: 175216012
parent a0e9c52921
commit f3f85e9aa0
Changed paths:
tensorflow/compiler/xla/service/
  buffer_assignment.cc
  cpu
  dfs_hlo_visitor.h
  dfs_hlo_visitor_with_default.h
  gpu
  hlo_cost_analysis.cc
  hlo_cost_analysis.h
  hlo_dataflow_analysis.cc
  hlo_dataflow_analysis.h
  hlo_dataflow_analysis_test.cc
  hlo_graph_dumper.cc
  hlo_instruction.cc
  hlo_instruction.h
  hlo_matchers.h
  hlo_opcode.h
  hlo_rematerialization.cc
  hlo_verifier.cc
  instruction_fusion.cc
  logical_buffer_analysis.cc
  logical_buffer_analysis.h
  tuple_points_to_analysis.cc
  tuple_points_to_analysis.h
  tuple_points_to_analysis_test.cc
  user_computation.cc
  while_loop_simplifier.cc
  while_loop_simplifier_test.cc
  tools/parser
docs_src/performance/xla
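Before the per-file diffs, here is a minimal sketch of how the new instruction pairs are built at the HLO level. It is illustrative only (names, payload shapes, and channel ids are made up) and mirrors the tests added in hlo_dataflow_analysis_test.cc and the parser tests below; the operation_semantics.md change at the bottom of this page describes the semantics in detail.

    // Illustrative only: in real use the send side and the recv side live in
    // different computations that share a channel id.
    auto builder = HloComputation::Builder("send_recv_sketch");

    // Receiver: Recv allocates resources and yields a (f32[], u32[]) tuple of
    // {receive buffer, request identifier}; RecvDone waits on that context and
    // returns the f32[] buffer (tuple element {0}).
    auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
        ShapeUtil::MakeShape(F32, {}), /*channel_id=*/15));
    auto recv_done =
        builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

    // Sender: Send initiates the transfer and yields a (f32[], u32[]) tuple of
    // {aliased operand, request identifier}; SendDone waits on that context
    // and returns no data.
    auto constant = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
    auto send = builder.AddInstruction(
        HloInstruction::CreateSend(constant, /*channel_id=*/16));
    auto send_done =
        builder.AddInstruction(HloInstruction::CreateSendDone(send));

In HLO text (see the updated parser tests below) these pairs read as `%recv = (f32[], u32[]) recv(), channel_id=15` followed by `%recv-done = f32[] recv-done(...)`, and likewise for `send`/`send-done`.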
tensorflow/compiler/xla/service/buffer_assignment.cc

@@ -819,17 +819,6 @@ Status BufferAssigner::AssignBuffersForComputation(
        continue;
      }

      if (instruction->opcode() == HloOpcode::kRecv) {
        // Make sure that recv operations get a new unique allocation so that
        // don't share their buffer with any other operations.
        BufferAllocation* allocation = assignment->NewAllocation(
            *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
        allocation_indices.push_back(allocation->index());
        VLOG(3) << "New allocation #" << allocation->index()
                << " for recv: " << *buffer;
        continue;
      }

      if (ShapeUtil::IsTuple(buffer->shape())) {
        // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
        // assumes longer buffer liveness than indicated by the analysis.
tensorflow/compiler/xla/service/cpu/ir_emitter.cc

@@ -1983,6 +1983,11 @@ Status IrEmitter::HandleSend(HloInstruction* send) {
  return Unimplemented("Send is not implemented on CPU. See b/33942983.");
}

Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
}

Status IrEmitter::HandleSlice(HloInstruction* slice) {
  VLOG(2) << "HandleSlice: " << slice->ToString();
  auto operand = slice->operand(0);

@@ -2148,6 +2153,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) {
  return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
}

Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
}

Status IrEmitter::HandlePad(HloInstruction* pad) {
  // CPU backend does not properly handle negative padding but this is ok
  // because negative padding should be removed by the algebraic simplifier.

tensorflow/compiler/xla/service/cpu/ir_emitter.h

@@ -171,11 +171,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleMap(HloInstruction* map) override;
tensorflow/compiler/xla/service/dfs_hlo_visitor.h

@@ -211,9 +211,11 @@ class DfsHloVisitorBase {

  virtual Status HandlePad(HloInstructionPtr hlo) = 0;

  virtual Status HandleSend(HloInstructionPtr hlo) = 0;
  virtual Status HandleSend(HloInstructionPtr send) = 0;
  virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;

  virtual Status HandleRecv(HloInstructionPtr hlo) = 0;
  virtual Status HandleRecv(HloInstructionPtr recv) = 0;
  virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;

  virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;

tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h

@@ -167,11 +167,17 @@ class DfsHloVisitorWithDefaultBase
  Status HandleWhile(HloInstructionPtr xla_while) override {
    return DefaultAction(xla_while);
  }
  Status HandleRecv(HloInstructionPtr recv) override {
    return DefaultAction(recv);
  }
  Status HandleRecvDone(HloInstructionPtr recv_done) override {
    return DefaultAction(recv_done);
  }
  Status HandleSend(HloInstructionPtr send) override {
    return DefaultAction(send);
  }
  Status HandleRecv(HloInstructionPtr recv) override {
    return DefaultAction(recv);
  Status HandleSendDone(HloInstructionPtr send_done) override {
    return DefaultAction(send_done);
  }

  // Invoked to inform the visitor that the traversal has completed, and that
tensorflow/compiler/xla/service/gpu/ir_emitter.cc

@@ -128,10 +128,18 @@ Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {

tensorflow/compiler/xla/service/gpu/ir_emitter.h

@@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleTuple(HloInstruction* tuple) override;
tensorflow/compiler/xla/service/hlo_cost_analysis.cc

@@ -337,10 +337,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

tensorflow/compiler/xla/service/hlo_cost_analysis.h

@@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  Status HandleReducePrecision(const HloInstruction* hlo) override;
  Status HandleConcatenate(const HloInstruction* concatenate) override;
  Status HandleSend(const HloInstruction* send) override;
  Status HandleSendDone(const HloInstruction* send_done) override;
  Status HandleRecv(const HloInstruction* recv) override;
  Status HandleRecvDone(const HloInstruction* recv_done) override;
  Status HandleConvert(const HloInstruction* convert) override;
  Status HandleCopy(const HloInstruction* copy) override;
  Status HandleDot(const HloInstruction* dot) override;
tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc

@@ -242,6 +242,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
  return false;
}

bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
  CHECK_EQ(send->opcode(), HloOpcode::kSend);
  bool changed = false;
  // Send forwards the operand value to the output tuple at {0}.
  for (auto& pair : GetInstructionValueSet(send->operand(0))) {
    const ShapeIndex& operand_index = pair.first;
    const HloValueSet& operand_value_set = pair.second;

    ShapeIndex index = {0};
    for (int64 i : operand_index) {
      index.push_back(i);
    }

    HloValueSet& value_set = GetValueSet(send, index);
    if (value_set != operand_value_set) {
      value_set = operand_value_set;
      changed = true;
    }
  }
  return changed;
}

bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
  CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
  bool changed = false;
  // RecvDone forwards the operand value at {0} to the output.
  for (auto& pair : GetInstructionValueSet(recv_done)) {
    ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    ShapeIndex operand_index = {0};
    for (int64 i : index) {
      operand_index.push_back(i);
    }

    const HloValueSet& operand_value_set =
        GetValueSet(recv_done->operand(0), operand_index);
    if (value_set != operand_value_set) {
      value_set = operand_value_set;
      changed = true;
    }
  }
  return changed;
}

bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
  CHECK_EQ(call->opcode(), HloOpcode::kCall);
  InstructionValueSet& value_set = GetInstructionValueSet(call);

@@ -429,6 +474,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
      return UpdateCallValueSet(instruction);
    case HloOpcode::kWhile:
      return UpdateWhileValueSet(instruction);
    case HloOpcode::kSend:
      return UpdateSendValueSet(instruction);
    case HloOpcode::kRecvDone:
      return UpdateRecvDoneValueSet(instruction);
    default:
      // Instruction does not forward HloValues (it defines all values in its
      // output). No update is necessary.

@@ -537,6 +586,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
      GetValueSet(instruction, /*index=*/{}).AddValue(value);
    };

    // Lambda to set the value set at the given index of the output.
    auto define_value_at = [this, &instruction](const ShapeIndex& index) {
      HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
      GetValueSet(instruction, index).AddValue(value);
    };

    switch (instruction->opcode()) {
      case HloOpcode::kBitcast:
        if (bitcast_defines_value_) {

@@ -577,6 +632,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
        // values flow from their operands.
        define_top_level_only();
        break;
      case HloOpcode::kRecvDone:
        // RecvDone aliases its input tuple element {0}, therefore does not
        // define any values.
        break;
      case HloOpcode::kSend:
        // Send produces a tuple of {aliased operand, U32 context}, therefore
        // only defines the top-level tuple and the tuple element at {1}.
        define_value_at(/*index=*/{});
        define_value_at(/*index=*/{1});
        break;
      default:
        define_all_values();
        break;

tensorflow/compiler/xla/service/hlo_dataflow_analysis.h

@@ -146,7 +146,9 @@ class HloDataflowAnalysis {
  bool UpdateCopyValueSet(HloInstruction* copy);
  bool UpdateGetTupleElementValueSet(HloInstruction* gte);
  bool UpdateParameterValueSet(HloInstruction* parameter);
  bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
  bool UpdateSelectValueSet(HloInstruction* select);
  bool UpdateSendValueSet(HloInstruction* send);
  bool UpdateTupleValueSet(HloInstruction* tuple);
  bool UpdateWhileValueSet(HloInstruction* xla_while);
tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc

@@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
      analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}

TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // Test that a Send forwards its operand to the output tuple at {0}.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  auto send = builder.AddInstruction(
      HloInstruction::CreateSend(param, /*channel_id=*/0));
  auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
  module_->AddEntryComputation(builder.Build());

  bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}

TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // Test that a RecvDone forwards its operand tuple element at {0} to the
  // output.
  auto builder = HloComputation::Builder(TestName());
  auto recv = builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
  auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  module_->AddEntryComputation(builder.Build());

  bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 3);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
  EXPECT_THAT(HloValuesAt(recv_done),
              UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}

TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A simple chain of elementwise operations. No values should interfere.
  //
tensorflow/compiler/xla/service/hlo_graph_dumper.cc

@@ -943,7 +943,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
    case HloOpcode::kFusion:
      return kGray;
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCrossReplicaSum:

@@ -1037,7 +1039,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
                 ? ""
                 : StrCat("stride=", VectorString(instr->slice_strides()));
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
      return StrCat("channel_id=", instr->channel_id());
    default:
      return "";
tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -371,20 +371,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, int64 channel_id) {
  // Send instruction produces a tuple of {aliased operand, U32 context}.
  Shape output_shape = ShapeUtil::MakeTupleShape(
      {operand->shape(), ShapeUtil::MakeShape(U32, {})});
  auto instruction =
      WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
      WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
  instruction->AppendOperand(operand);
  instruction->channel_id_ = channel_id;
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand) {
  CHECK(operand->opcode() == HloOpcode::kSend)
      << "SendDone must take the context operand from Send";
  auto instruction = WrapUnique(
      new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
  instruction->AppendOperand(operand);
  instruction->channel_id_ = operand->channel_id();
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, int64 channel_id) {
  auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
  // Recv instruction produces a tuple of {receive buffer, U32 context}.
  Shape output_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
  auto instruction =
      WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
  instruction->channel_id_ = channel_id;
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand) {
  CHECK(operand->opcode() == HloOpcode::kRecv)
      << "RecvDone must take the context operand from Recv";
  Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
  auto instruction =
      WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
  instruction->AppendOperand(operand);
  instruction->channel_id_ = operand->channel_id();
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {

@@ -908,7 +938,9 @@ RandomDistribution HloInstruction::random_distribution() const {
bool HloInstruction::HasSideEffect() const {
  switch (opcode_) {
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kTrace:

@@ -1164,7 +1196,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
          new_operands[4], epsilon(), feature_index());
      break;
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kTrace:
      LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
  }

@@ -1557,8 +1591,10 @@ bool HloInstruction::IdenticalSlowPath(
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kSort:
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
      return false;
  }
}

@@ -1891,7 +1927,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
        })));
  }

  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
      opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
    extra.push_back(StrCat("channel_id=", channel_id_));
  }

@@ -2071,8 +2108,10 @@ bool HloInstruction::IsFusable() const {
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kTrace:
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
      return false;
    // Only fuse Rng if it is used once, otherwise the random numbers generated
    // will be different in each fusion. If it is the root (user count = 0)

@@ -2279,10 +2318,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
      return visitor->HandleCall(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
tensorflow/compiler/xla/service/hlo_instruction.h

@@ -181,18 +181,28 @@ class HloInstruction {
      const Shape& shape, HloInstruction* operand,
      tensorflow::StringPiece outfeed_config);

  // Creates a send instruction with the given channel id, which sends the
  // operand data to a unique receive instruction in another computation that
  // has the same channel id.
  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                    int64 channel_id);

  // Creates a receive instruction with the given channel id, which receives
  // data of the given shape from a unique send instruction in another
  // computation that has the same channel id.
  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                    int64 channel_id);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
tensorflow/compiler/xla/service/hlo_matchers.h

@@ -121,6 +121,7 @@ HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
HLO_MATCHER(Power);
HLO_MATCHER(Recv);
HLO_MATCHER(RecvDone);
HLO_MATCHER(Reduce);
HLO_MATCHER(ReducePrecision);
HLO_MATCHER(ReduceWindow);

@@ -131,6 +132,7 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);
HLO_MATCHER(SendDone);
HLO_MATCHER(ShiftLeft);
HLO_MATCHER(ShiftRightLogical);
HLO_MATCHER(ShiftRightArithmetic);

tensorflow/compiler/xla/service/hlo_opcode.h

@@ -97,6 +97,7 @@ namespace xla {
  V(kPower, "power") \
  V(kReal, "real") \
  V(kRecv, "recv") \
  V(kRecvDone, "recv-done") \
  V(kReduce, "reduce") \
  V(kReducePrecision, "reduce-precision") \
  V(kReduceWindow, "reduce-window") \

@@ -108,6 +109,7 @@ namespace xla {
  V(kSelect, "select") \
  V(kSelectAndScatter, "select-and-scatter") \
  V(kSend, "send") \
  V(kSendDone, "send-done") \
  V(kShiftLeft, "shift-left") \
  V(kShiftRightArithmetic, "shift-right-arithmetic") \
  V(kShiftRightLogical, "shift-right-logical") \
tensorflow/compiler/xla/service/hlo_rematerialization.cc

@@ -66,7 +66,9 @@ bool IsRematerializable(const HloInstruction* instruction) {
    case HloOpcode::kInfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kTrace:
    case HloOpcode::kWhile:
      return false;
tensorflow/compiler/xla/service/hlo_verifier.cc

@@ -270,12 +270,40 @@ class ShapeVerifier : public DfsHloVisitor {
                      pad->padding_config()));
  }

  Status HandleSend(HloInstruction*) override {
    return tensorflow::Status::OK();
  Status HandleSend(HloInstruction* send) override {
    TF_RET_CHECK(send->users().size() == 1);
    const HloInstruction* send_done = send->users()[0];
    TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
    TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
    return CheckShape(
        send, ShapeUtil::MakeTupleShape(
                  {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
  }

  Status HandleRecv(HloInstruction*) override {
    return tensorflow::Status::OK();
  Status HandleSendDone(HloInstruction* send_done) override {
    TF_RET_CHECK(send_done->operands().size() == 1);
    const HloInstruction* send = send_done->operand(0);
    TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
    TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
    return CheckShape(send_done, ShapeUtil::MakeNil());
  }

  Status HandleRecv(HloInstruction* recv) override {
    TF_RET_CHECK(recv->users().size() == 1);
    const HloInstruction* recv_done = recv->users()[0];
    TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
    TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
    return CheckShape(recv,
                      ShapeUtil::MakeTupleShape(
                          {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
  }

  Status HandleRecvDone(HloInstruction* recv_done) override {
    TF_RET_CHECK(recv_done->operands().size() == 1);
    const HloInstruction* recv = recv_done->operand(0);
    TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
    TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
    return CheckShape(recv_done, recv->shape().tuple_shapes(0));
  }

  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {

@@ -365,6 +393,19 @@ class ShapeVerifier : public DfsHloVisitor {
        instruction->opcode(), instruction->operands()));
  }

  // Checks if the given two instructions shares the same channel id.
  Status CheckSameChannel(const HloInstruction* instr1,
                          const HloInstruction* instr2) {
    if (instr1->channel_id() != instr2->channel_id()) {
      return FailedPrecondition(
          "Expected to have the same channel id, actual channel ids are: %s "
          "(%lld), %s (%lld)",
          instr1->ToString().c_str(), instr1->channel_id(),
          instr2->ToString().c_str(), instr2->channel_id());
    }
    return tensorflow::Status::OK();
  }

  // Returns the size of a Shape in bytes.
  const std::function<int64(const Shape&)> shape_size_fn_;
};
tensorflow/compiler/xla/service/instruction_fusion.cc

@@ -113,7 +113,9 @@ namespace xla {
    case HloOpcode::kTrace:
    case HloOpcode::kWhile:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
      return true;
  }
tensorflow/compiler/xla/service/logical_buffer_analysis.cc

@@ -104,6 +104,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
  // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
  // tuple element at {0} to its output.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
  // Send creates new buffers for the top-level tuple and the context (tuple
  // element at {1}). Tuple element at {0} is an alias of the Send operand, so
  // we don't need to create a new Logical Buffer for that.
  NewLogicalBuffer(send, /*index=*/{});
  NewLogicalBuffer(send, /*index=*/{1});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
  // A Tuple instruction only creates the top-level buffer.
  NewLogicalBuffer(tuple, /*index=*/{});

tensorflow/compiler/xla/service/logical_buffer_analysis.h

@@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSelect(HloInstruction* select) override;

  // A map from the buffer ID to the logical buffer
tensorflow/compiler/xla/service/tuple_points_to_analysis.cc

@@ -253,6 +253,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
  return Status::OK();
}

Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
  // RecvDone aliases its input (Recv) tuple element {0} to its output.
  PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
  const PointsToSet& operand_points_to_set =
      GetPointsToSet(recv_done->operand(0));

  // Recursively copy the points to set of the operand tuple {0}.
  points_to_set.ForEachMutableElement(
      [this, &points_to_set, &operand_points_to_set](
          const ShapeIndex& index, PointsToSet::BufferList* buffers) {
        ShapeIndex src_index({0});
        for (auto element : index) {
          src_index.push_back(element);
        }
        *buffers = operand_points_to_set.element(src_index);
        for (auto& tuple_source :
             operand_points_to_set.tuple_sources(src_index)) {
          points_to_set.add_tuple_source(index, tuple_source);
        }
      });
  return Status::OK();
}

Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
  // Send creates a tuple of {aliased operand, U32 context}.
  PointsToSet& points_to_set = CreateEmptyPointsToSet(send);

  // Creates the points to set for the tuple and its element at {1}.
  auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
  top_buffer->push_back(
      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
  points_to_set.add_tuple_source({}, send);

  auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
  context_buffer->push_back(
      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));

  // Recursively copy the points to set of the operand to output tuple {0}.
  const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
  operand_points_to_set.ForEachElement(
      [&points_to_set, &operand_points_to_set](
          const ShapeIndex& src_index,
          const PointsToSet::BufferList& points_to) {
        ShapeIndex target_index({0});
        for (auto element : src_index) {
          target_index.push_back(element);
        }
        *points_to_set.mutable_element(target_index) = points_to;

        for (HloInstruction* tuple :
             operand_points_to_set.tuple_sources(src_index)) {
          points_to_set.add_tuple_source(target_index, tuple);
        }
      });

  return Status::OK();
}

Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
  PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);

tensorflow/compiler/xla/service/tuple_points_to_analysis.h

@@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSelect(HloInstruction* select) override;

  string ToString() const;
tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc

@@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
      {constant1, constant2, copy});
}

TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
  // Send forwards its operand to the output tuple at {0}.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto send = builder.AddInstruction(
      HloInstruction::CreateSend(constant, /*channel_id=*/0));
  auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());

  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(send).element({}), {send});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
      {send_done});
  ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
}

TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
  // RecvDone forwards its operand tuple element at {0} to the output.
  auto builder = HloComputation::Builder(TestName());
  auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
      ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
  auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());

  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
}

TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // Select from two different tuples. This should create an ambiguous points to
  // set containing the union of both sides.
tensorflow/compiler/xla/service/user_computation.cc

@@ -2927,8 +2927,9 @@ void ComputationLowerer::Visit(

    case OpRequest::kRecvRequest: {
      const RecvRequest& recv_request = request.request().recv_request();
      hlo_instruction = add_instruction(HloInstruction::CreateRecv(
      HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
          request.output_shape(), recv_request.channel_handle().handle()));
      hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
      break;
    }

@@ -3120,8 +3121,9 @@ void ComputationLowerer::Visit(
    case OpRequest::kSendRequest: {
      const SendRequest& send_request = request.request().send_request();
      HloInstruction* operand = lookup_instruction(send_request.operand());
      hlo_instruction = add_instruction(HloInstruction::CreateSend(
      HloInstruction* send = add_instruction(HloInstruction::CreateSend(
          operand, send_request.channel_handle().handle()));
      hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
      break;
    }
tensorflow/compiler/xla/service/while_loop_simplifier.cc

@@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) {

static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kSend ||
      instr->opcode() == HloOpcode::kRecv) {
      instr->opcode() == HloOpcode::kSendDone ||
      instr->opcode() == HloOpcode::kRecv ||
      instr->opcode() == HloOpcode::kRecvDone) {
    return true;
  }
  for (const auto& subcomp : instr->called_computations()) {

tensorflow/compiler/xla/service/while_loop_simplifier_test.cc

@@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
  auto* while_op = computation->root_instruction();
  ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
  auto* while_body = while_op->while_body();
  while_body->AddInstruction(HloInstruction::CreateSend(
  auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
      while_body->AddInstruction(
          HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
      /*channel_id=*/0));
  while_body->AddInstruction(HloInstruction::CreateSendDone(send));
  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}

@@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
  auto* while_op = computation->root_instruction();
  ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
  auto* while_body = while_op->while_body();
  while_body->AddInstruction(
  auto* recv = while_body->AddInstruction(
      HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
                                 /*channel_id=*/0));
  while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
tensorflow/compiler/xla/tools/parser/hlo_parser.cc

@@ -442,7 +442,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRecv(shape, *channel_id));
          HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
      break;
    }
    case HloOpcode::kRecvDone: {
      optional<int64> channel_id;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
      break;
    }
    case HloOpcode::kSend: {

@@ -456,6 +470,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
          HloInstruction::CreateSend(operands[0], *channel_id));
      break;
    }
    case HloOpcode::kSendDone: {
      optional<int64> channel_id;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
      break;
    }
    case HloOpcode::kGetTupleElement: {
      optional<int64> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

@@ -226,9 +226,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module:

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
  ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
  %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
  ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
  %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
}

)"

@@ -522,9 +524,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
  const string original = R"(HloModule unexpected_attr_module:

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = f32[] recv(), channel_id=15
  %recv = (f32[], u32[]) recv(), channel_id=15
  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
  ROOT %constant = f32[] constant(2.1)
  %send = () send(f32[] %constant), channel_id=16, calls=%recv
  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
}

)";

@@ -536,9 +540,11 @@ TEST_F(HloParserTest, MissingAttribute) {
  const string original = R"(HloModule missing_attr_module:

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = f32[] recv(), channel_id=15
  %recv = (f32[], u32[]) recv(), channel_id=15
  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
  ROOT %constant = f32[] constant(-2.1)
  %send = () send(f32[] %constant)
  %send = (f32[], u32[]) send(f32[] %constant)
  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
}

)";

@@ -550,9 +556,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
  const string original = R"(HloModule pre_not_found_module:

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = f32[] recv(), channel_id=15
  %recv = (f32[], u32[]) recv(), channel_id=15
  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
  ROOT %constant = f32[] constant(2.1)
  %send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
}

)";
tensorflow/docs_src/performance/xla/operation_semantics.md

@@ -901,6 +901,95 @@ are all 0. Figure below shows examples of different `edge_padding` and
<img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
</div>

## Recv

See also
[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).

<b> `Recv(shape, channel_handle)` </b>

| Arguments | Type | Semantics |
| ---------------- | --------------- | ------------------------------------ |
| `shape` | `Shape` | shape of the data to receive |
| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |

Receives data of the given shape from a `Send` instruction in another
computation that shares the same channel handle. Returns a
ComputationDataHandle for the received data.

The client API of the `Recv` operation represents synchronous communication.
However, the instruction is internally decomposed into 2 HLO instructions
(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).

<b>`Recv(const Shape& shape, int64 channel_id)`</b>

Allocates resources required to receive data from a `Send` instruction with the
same channel_id. Returns a context for the allocated resources, which is used
by a following `RecvDone` instruction to wait for the completion of the data
transfer. The context is a tuple of {receive buffer (shape), request identifier
(U32)} and it can only be used by a `RecvDone` instruction.

<b> `RecvDone(HloInstruction context)` </b>

Given a context created by a `Recv` instruction, waits for the data transfer to
complete and returns the received data.
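As a rough illustration of the decomposition described above, a minimal sketch at the HLO level; the payload shape, channel id, and computation name are illustrative only, following the tests added in this change:

    // Sketch only: Recv yields a (f32[], u32[]) context tuple of
    // {receive buffer, request identifier}; RecvDone consumes that kRecv
    // instruction and returns the f32[] receive buffer (tuple element {0}).
    auto builder = HloComputation::Builder("recv_example");
    auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
        ShapeUtil::MakeShape(F32, {}), /*channel_id=*/15));
    auto recv_done =
        builder.AddInstruction(HloInstruction::CreateRecvDone(recv));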
## Send

See also
[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).

<b> `Send(operand, channel_handle)` </b>

| Arguments | Type | Semantics |
| ---------------- | ----------------------- | -------------------------------- |
| `operand` | `ComputationDataHandle` | data to send (array of type T) |
| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |

Sends the given operand data to a `Recv` instruction in another computation
that shares the same channel handle. Does not return any data.

Similar to the `Recv` operation, the client API of the `Send` operation represents
synchronous communication, and is internally decomposed into 2 HLO instructions
(`Send` and `SendDone`) to enable asynchronous data transfers. See also
[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).

<b>`Send(HloInstruction operand, int64 channel_id)`</b>

Initiates an asynchronous transfer of the operand to the resources allocated by
the `Recv` instruction with the same channel id. Returns a context, which is
used by a following `SendDone` instruction to wait for the completion of the
data transfer. The context is a tuple of {operand (shape), request identifier
(U32)} and it can only be used by a `SendDone` instruction.

<b> `SendDone(HloInstruction context)` </b>

Given a context created by a `Send` instruction, waits for the data transfer to
complete. The instruction does not return any data.
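The corresponding sketch for the send side, again illustrative only (operand, channel id, and computation name are made up, mirroring the tests added in this change):

    // Sketch only: Send yields a (f32[], u32[]) context tuple of
    // {aliased operand, request identifier}; SendDone consumes that kSend
    // instruction, waits for the transfer, and produces no data.
    auto builder = HloComputation::Builder("send_example");
    auto operand = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
    auto send = builder.AddInstruction(
        HloInstruction::CreateSend(operand, /*channel_id=*/15));
    auto send_done =
        builder.AddInstruction(HloInstruction::CreateSendDone(send));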
<b> Scheduling of channel instructions </b>

The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
`Send`, `SendDone`) is as follows.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:70%" src="../../images/send_recv_order.png">
</div>

* `Recv` happens before `Send`
* `Send` happens before `RecvDone`
* `Recv` happens before `RecvDone`
* `Send` happens before `SendDone`

When the backend compilers generate a linear schedule for each computation that
communicates via channel instructions, there must not be cycles across the
computations. For example, the schedules below lead to deadlocks.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/send_recv_schedule.png">
</div>

## Reduce

See also