Add some XLA frontend attribute names.
PiperOrigin-RevId: 290190699
Change-Id: I421510149dbc759fbe3e06a4990502d4772962b5
parent 18645e7a3c
commit 6b525249b8
tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -34,6 +34,15 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
 const char kXlaOriginalOutsideCompilationNodeName[] =
     "_xla_original_oc_node_name";
 
+const char kXlaHostTransferRendezvousNameAttr[] =
+    "_xla_host_transfer_rendezvous";
+
+const char kXlaHostTransferOriginalTypeAttr[] =
+    "_xla_host_transfer_original_type";
+
+const char kXlaHostTransferIsLowerBitsAttr[] =
+    "_xla_host_transfer_is_lower_bits";
+
 Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
   if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
     return errors::InvalidArgument("Node ", node->DebugString(),
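
For orientation, a minimal sketch (not part of this commit) of how an attribute name like these might be attached to a TensorFlow NodeDef. AddNodeAttr is the existing helper in tensorflow/core/framework/node_def_util.h; the wrapper function and the rendezvous value below are hypothetical.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"

// Hypothetical helper: record a rendezvous name on a host-transfer node.
// "example_rendezvous_key" is a placeholder value, not taken from the commit.
void TagHostTransferRendezvous(tensorflow::NodeDef* node_def) {
  tensorflow::AddNodeAttr("_xla_host_transfer_rendezvous",
                          "example_rendezvous_key", node_def);
}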
tensorflow/compiler/tf2xla/side_effect_util.h
@@ -64,6 +64,18 @@ bool HasSideEffectingNodes(const Graph& g);
 Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
                                 std::map<string, int>* host_compute_core);
 
+// XLA frontend attribute name which specifies TensorFlow rendezvous name.
+extern const char kXlaHostTransferRendezvousNameAttr[];
+
+// XLA frontend attribute name which specifies original host transfer type.
+// Value is XLA primitive type in lower case.
+extern const char kXlaHostTransferOriginalTypeAttr[];
+
+// XLA frontend attribute name which specifies whether a host transfer
+// instruction is the lower bits of a split X64 host transfer. Value is "true"
+// or "false".
+extern const char kXlaHostTransferIsLowerBitsAttr[];
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
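
The "lower bits" attribute above concerns X64 host transfers that get split into two 32-bit halves. A standalone C++ illustration of that split, with a made-up function name and no TensorFlow dependency:

#include <cstdint>
#include <utility>

// Illustration: split a 64-bit payload into {low, high} 32-bit halves. The
// transfer carrying `low` is the one that would be tagged
// _xla_host_transfer_is_lower_bits = "true".
std::pair<uint32_t, uint32_t> SplitX64ForIllustration(uint64_t v) {
  const uint32_t low = static_cast<uint32_t>(v);         // lower 32 bits
  const uint32_t high = static_cast<uint32_t>(v >> 32);  // upper 32 bits
  return {low, high};
}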
tensorflow/compiler/xla/service/hlo_query.cc
@@ -148,5 +148,25 @@ int64 NextChannelId(const HloModule& module) {
   return next_channel_id;
 }
 
+bool HasX64TransformedHostTransfer(const HloModule& module) {
+  for (auto computation : module.computations()) {
+    for (auto hlo : computation->instructions()) {
+      if (hlo->opcode() == HloOpcode::kSend) {
+        auto send = DynCast<HloSendInstruction>(hlo);
+        if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
+          return true;
+        }
+      } else if (hlo->opcode() == HloOpcode::kRecv) {
+        auto recv = DynCast<HloRecvInstruction>(hlo);
+        if (recv->is_host_transfer() &&
+            recv->shape().tuple_shapes(0).IsTuple()) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace hlo_query
 }  // namespace xla
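
A hedged usage sketch for the new query; only HasX64TransformedHostTransfer comes from this commit, and the wrapper below is hypothetical. A pass that assumes non-tuple host transfers could use it to bail out:

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"

// Hypothetical guard: skip an optimization that does not understand the
// tuple-shaped host transfers produced by the X64 rewrite.
bool ShouldSkipForX64Transfers(const xla::HloModule& module) {
  return xla::hlo_query::HasX64TransformedHostTransfer(module);
}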
tensorflow/compiler/xla/service/hlo_query.h
@@ -81,6 +81,11 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
 // (for HloChannelInstructions).
 int64 NextChannelId(const HloModule& module);
 
+// Returns whether the module contains host send/recv with X64 data type.
+// This function is called after X64Rewriter, so X64 host transfers are already
+// rewritten into tuple shaped transfers.
+bool HasX64TransformedHostTransfer(const HloModule& module);
+
 }  // namespace hlo_query
 }  // namespace xla