Add some XLA frontend attribute names.

PiperOrigin-RevId: 290190699
Change-Id: I421510149dbc759fbe3e06a4990502d4772962b5
This commit is contained in:
Tong Shen 2020-01-16 18:58:04 -08:00 committed by TensorFlower Gardener
parent 18645e7a3c
commit 6b525249b8
4 changed files with 46 additions and 0 deletions

View File

@ -34,6 +34,15 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
const char kXlaOriginalOutsideCompilationNodeName[] =
"_xla_original_oc_node_name";
const char kXlaHostTransferRendezvousNameAttr[] =
"_xla_host_transfer_rendezvous";
const char kXlaHostTransferOriginalTypeAttr[] =
"_xla_host_transfer_original_type";
const char kXlaHostTransferIsLowerBitsAttr[] =
"_xla_host_transfer_is_lower_bits";
Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
return errors::InvalidArgument("Node ", node->DebugString(),

View File

@ -64,6 +64,18 @@ bool HasSideEffectingNodes(const Graph& g);
Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
std::map<string, int>* host_compute_core);
// XLA frontend attribute name which specifies TensorFlow rendezvous name.
extern const char kXlaHostTransferRendezvousNameAttr[];
// XLA frontend attribute name which specifies original host transfer type.
// Value is XLA primitive type in lower case.
extern const char kXlaHostTransferOriginalTypeAttr[];
// XLA frontend attribute name which specifies whether a host transfer
// instruction is lower bits for a splitted X64 host transfer. Value is "true"
// or "false".
extern const char kXlaHostTransferIsLowerBitsAttr[];
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_

View File

@ -148,5 +148,25 @@ int64 NextChannelId(const HloModule& module) {
return next_channel_id;
}
bool HasX64TransformedHostTransfer(const HloModule& module) {
for (auto computation : module.computations()) {
for (auto hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kSend) {
auto send = DynCast<HloSendInstruction>(hlo);
if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
return true;
}
} else if (hlo->opcode() == HloOpcode::kRecv) {
auto recv = DynCast<HloRecvInstruction>(hlo);
if (recv->is_host_transfer() &&
recv->shape().tuple_shapes(0).IsTuple()) {
return true;
}
}
}
}
return false;
}
} // namespace hlo_query
} // namespace xla

View File

@ -81,6 +81,11 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
// (for HloChannelInstructions).
int64 NextChannelId(const HloModule& module);
// Returns whether the module contains host send/recv with X64 data type.
// This function is called after X64Rewriter, so X64 host transfers are already
// rewritten into tuple shaped transfers.
bool HasX64TransformedHostTransfer(const HloModule& module);
} // namespace hlo_query
} // namespace xla