From 6b525249b8be9db9fd58a6e22696229fac538047 Mon Sep 17 00:00:00 2001
From: Tong Shen <endlessroad@google.com>
Date: Thu, 16 Jan 2020 18:58:04 -0800
Subject: [PATCH] Add some XLA frontend attribute names.

PiperOrigin-RevId: 290190699
Change-Id: I421510149dbc759fbe3e06a4990502d4772962b5
---
 .../compiler/tf2xla/side_effect_util.cc       |  9 +++++++++
 tensorflow/compiler/tf2xla/side_effect_util.h | 12 +++++++++++
 tensorflow/compiler/xla/service/hlo_query.cc  | 20 +++++++++++++++++++
 tensorflow/compiler/xla/service/hlo_query.h   |  5 +++++
 4 files changed, 46 insertions(+)

diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
index d6a6540f072..10774cef6d1 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.cc
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -34,6 +34,15 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
 const char kXlaOriginalOutsideCompilationNodeName[] =
     "_xla_original_oc_node_name";
 
+const char kXlaHostTransferRendezvousNameAttr[] =
+    "_xla_host_transfer_rendezvous";
+
+const char kXlaHostTransferOriginalTypeAttr[] =
+    "_xla_host_transfer_original_type";
+
+const char kXlaHostTransferIsLowerBitsAttr[] =
+    "_xla_host_transfer_is_lower_bits";
+
 Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
   if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
     return errors::InvalidArgument("Node ", node->DebugString(),
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
index f91fe75c8a4..738be06f16a 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.h
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -64,6 +64,18 @@ bool HasSideEffectingNodes(const Graph& g);
 Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
                                 std::map<string, int>* host_compute_core);
 
+// XLA frontend attribute name which specifies TensorFlow rendezvous name.
+extern const char kXlaHostTransferRendezvousNameAttr[];
+
+// XLA frontend attribute name which specifies original host transfer type.
+// Value is XLA primitive type in lower case.
+extern const char kXlaHostTransferOriginalTypeAttr[];
+
+// XLA frontend attribute name which specifies whether a host transfer
+// instruction is lower bits for a splitted X64 host transfer. Value is "true"
+// or "false".
+extern const char kXlaHostTransferIsLowerBitsAttr[];
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index f6ee4096b0c..46bc6574f9d 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -148,5 +148,25 @@ int64 NextChannelId(const HloModule& module) {
   return next_channel_id;
 }
 
+bool HasX64TransformedHostTransfer(const HloModule& module) {
+  for (auto computation : module.computations()) {
+    for (auto hlo : computation->instructions()) {
+      if (hlo->opcode() == HloOpcode::kSend) {
+        auto send = DynCast<HloSendInstruction>(hlo);
+        if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
+          return true;
+        }
+      } else if (hlo->opcode() == HloOpcode::kRecv) {
+        auto recv = DynCast<HloRecvInstruction>(hlo);
+        if (recv->is_host_transfer() &&
+            recv->shape().tuple_shapes(0).IsTuple()) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace hlo_query
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
index b7fbc465dcb..e1a4e069cc3 100644
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -81,6 +81,11 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
 // (for HloChannelInstructions).
 int64 NextChannelId(const HloModule& module);
 
+// Returns whether the module contains host send/recv with X64 data type.
+// This function is called after X64Rewriter, so X64 host transfers are already
+// rewritten into tuple shaped transfers.
+bool HasX64TransformedHostTransfer(const HloModule& module);
+
 }  // namespace hlo_query
 }  // namespace xla