diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
index d6a6540f072..10774cef6d1 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.cc
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -34,6 +34,15 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
 const char kXlaOriginalOutsideCompilationNodeName[] =
     "_xla_original_oc_node_name";
 
+const char kXlaHostTransferRendezvousNameAttr[] =
+    "_xla_host_transfer_rendezvous";
+
+const char kXlaHostTransferOriginalTypeAttr[] =
+    "_xla_host_transfer_original_type";
+
+const char kXlaHostTransferIsLowerBitsAttr[] =
+    "_xla_host_transfer_is_lower_bits";
+
 Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
   if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
     return errors::InvalidArgument("Node ", node->DebugString(),
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
index f91fe75c8a4..738be06f16a 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.h
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -64,6 +64,18 @@ bool HasSideEffectingNodes(const Graph& g);
 Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
                                 std::map<string, int>* host_compute_core);
 
+// XLA frontend attribute name which specifies the TensorFlow rendezvous name.
+extern const char kXlaHostTransferRendezvousNameAttr[];
+
+// XLA frontend attribute name which specifies the original host transfer
+// type. Value is the XLA primitive type in lower case.
+extern const char kXlaHostTransferOriginalTypeAttr[];
+
+// XLA frontend attribute name which specifies whether a host transfer
+// instruction is the lower bits of a split X64 host transfer. Value is "true"
+// or "false".
+extern const char kXlaHostTransferIsLowerBitsAttr[];
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index f6ee4096b0c..46bc6574f9d 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -148,5 +148,25 @@ int64 NextChannelId(const HloModule& module) {
   return next_channel_id;
 }
 
+bool HasX64TransformedHostTransfer(const HloModule& module) {
+  for (auto computation : module.computations()) {
+    for (auto hlo : computation->instructions()) {
+      if (hlo->opcode() == HloOpcode::kSend) {
+        auto send = DynCast<HloSendInstruction>(hlo);
+        if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
+          return true;
+        }
+      } else if (hlo->opcode() == HloOpcode::kRecv) {
+        auto recv = DynCast<HloRecvInstruction>(hlo);
+        if (recv->is_host_transfer() &&
+            recv->shape().tuple_shapes(0).IsTuple()) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace hlo_query
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
index b7fbc465dcb..e1a4e069cc3 100644
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -81,6 +81,11 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
 // (for HloChannelInstructions).
 int64 NextChannelId(const HloModule& module);
 
+// Returns whether the module contains a host send/recv with an X64 data type.
+// This function is called after X64Rewriter, so X64 host transfers have
+// already been rewritten into tuple-shaped transfers.
+bool HasX64TransformedHostTransfer(const HloModule& module);
+
 }  // namespace hlo_query
 }  // namespace xla
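
For context on the tuple-shape checks in HasX64TransformedHostTransfer: per the header comment above, X64Rewriter rewrites a 64-bit host transfer into a tuple-shaped transfer, presumably carrying the value as two 32-bit halves, which is what kXlaHostTransferIsLowerBitsAttr distinguishes. The following standalone C++ sketch of that split is illustrative only; SplitX64 and CombineX64 are hypothetical helpers, not part of this patch or of XLA.

#include <cassert>
#include <cstdint>
#include <utility>

// Illustrative only -- not XLA code. Models a 64-bit host-transfer value
// traveling as a (lower, upper) pair of 32-bit halves; the instruction
// carrying the first element is the one kXlaHostTransferIsLowerBitsAttr
// would mark with "true".
std::pair<uint32_t, uint32_t> SplitX64(uint64_t v) {
  const uint32_t lower = static_cast<uint32_t>(v & 0xFFFFFFFFu);  // low bits
  const uint32_t upper = static_cast<uint32_t>(v >> 32);          // high bits
  return {lower, upper};
}

uint64_t CombineX64(uint32_t lower, uint32_t upper) {
  return (static_cast<uint64_t>(upper) << 32) | lower;
}

int main() {
  const uint64_t v = 0x0123456789ABCDEFull;
  const auto [lower, upper] = SplitX64(v);
  assert(CombineX64(lower, upper) == v);
  return 0;
}

Because the rewritten transfer carries both halves together, a post-rewrite X64 send has a tuple operand and a post-rewrite X64 recv yields a tuple in element 0 of its result, which appears to be why the query tests for tuple shapes rather than for a 64-bit element type.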