From 6b525249b8be9db9fd58a6e22696229fac538047 Mon Sep 17 00:00:00 2001 From: Tong Shen <endlessroad@google.com> Date: Thu, 16 Jan 2020 18:58:04 -0800 Subject: [PATCH] Add some XLA frontend attribute names. PiperOrigin-RevId: 290190699 Change-Id: I421510149dbc759fbe3e06a4990502d4772962b5 --- .../compiler/tf2xla/side_effect_util.cc | 9 +++++++++ tensorflow/compiler/tf2xla/side_effect_util.h | 12 +++++++++++ tensorflow/compiler/xla/service/hlo_query.cc | 20 +++++++++++++++++++ tensorflow/compiler/xla/service/hlo_query.h | 5 +++++ 4 files changed, 46 insertions(+) diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index d6a6540f072..10774cef6d1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -34,6 +34,15 @@ const char kXlaIsPlaceholderForTailOcAttrName[] = const char kXlaOriginalOutsideCompilationNodeName[] = "_xla_original_oc_node_name"; +const char kXlaHostTransferRendezvousNameAttr[] = + "_xla_host_transfer_rendezvous"; + +const char kXlaHostTransferOriginalTypeAttr[] = + "_xla_host_transfer_original_type"; + +const char kXlaHostTransferIsLowerBitsAttr[] = + "_xla_host_transfer_is_lower_bits"; + Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { return errors::InvalidArgument("Node ", node->DebugString(), diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f91fe75c8a4..738be06f16a 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -64,6 +64,18 @@ bool HasSideEffectingNodes(const Graph& g); Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr, std::map<string, int>* host_compute_core); +// XLA frontend attribute name which specifies TensorFlow rendezvous name. +extern const char kXlaHostTransferRendezvousNameAttr[]; + +// XLA frontend attribute name which specifies original host transfer type. +// Value is XLA primitive type in lower case. +extern const char kXlaHostTransferOriginalTypeAttr[]; + +// XLA frontend attribute name which specifies whether a host transfer +// instruction is lower bits for a splitted X64 host transfer. Value is "true" +// or "false". +extern const char kXlaHostTransferIsLowerBitsAttr[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index f6ee4096b0c..46bc6574f9d 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -148,5 +148,25 @@ int64 NextChannelId(const HloModule& module) { return next_channel_id; } +bool HasX64TransformedHostTransfer(const HloModule& module) { + for (auto computation : module.computations()) { + for (auto hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kSend) { + auto send = DynCast<HloSendInstruction>(hlo); + if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) { + return true; + } + } else if (hlo->opcode() == HloOpcode::kRecv) { + auto recv = DynCast<HloRecvInstruction>(hlo); + if (recv->is_host_transfer() && + recv->shape().tuple_shapes(0).IsTuple()) { + return true; + } + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index b7fbc465dcb..e1a4e069cc3 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -81,6 +81,11 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module); // (for HloChannelInstructions). int64 NextChannelId(const HloModule& module); +// Returns whether the module contains host send/recv with X64 data type. +// This function is called after X64Rewriter, so X64 host transfers are already +// rewritten into tuple shaped transfers. +bool HasX64TransformedHostTransfer(const HloModule& module); + } // namespace hlo_query } // namespace xla