diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 10774cef6d1..d6a6540f072 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -34,15 +34,6 @@ const char kXlaIsPlaceholderForTailOcAttrName[] = const char kXlaOriginalOutsideCompilationNodeName[] = "_xla_original_oc_node_name"; -const char kXlaHostTransferRendezvousNameAttr[] = - "_xla_host_transfer_rendezvous"; - -const char kXlaHostTransferOriginalTypeAttr[] = - "_xla_host_transfer_original_type"; - -const char kXlaHostTransferIsLowerBitsAttr[] = - "_xla_host_transfer_is_lower_bits"; - Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { return errors::InvalidArgument("Node ", node->DebugString(), diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index 738be06f16a..f91fe75c8a4 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -64,18 +64,6 @@ bool HasSideEffectingNodes(const Graph& g); Status ParseHostComputeCoreList(absl::Span list_from_attr, std::map* host_compute_core); -// XLA frontend attribute name which specifies TensorFlow rendezvous name. -extern const char kXlaHostTransferRendezvousNameAttr[]; - -// XLA frontend attribute name which specifies original host transfer type. -// Value is XLA primitive type in lower case. -extern const char kXlaHostTransferOriginalTypeAttr[]; - -// XLA frontend attribute name which specifies whether a host transfer -// instruction is lower bits for a splitted X64 host transfer. Value is "true" -// or "false". -extern const char kXlaHostTransferIsLowerBitsAttr[]; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 6e996186c28..79be98568b9 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -964,6 +964,12 @@ cc_library( hdrs = ["union_find.h"], ) +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index fd6d62ac2f7..e402120fb15 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -209,6 +209,7 @@ class Layout { return *this; } static constexpr int64 kDefaultMemorySpace = 0; + static constexpr int64 kGenericFastMemorySpace = 1; int64 memory_space() const { return memory_space_; } Layout& set_memory_space(int64 value) { memory_space_ = value; diff --git a/tensorflow/compiler/xla/side_effect_util.cc b/tensorflow/compiler/xla/side_effect_util.cc new file mode 100644 index 00000000000..337d5acd3e6 --- /dev/null +++ b/tensorflow/compiler/xla/side_effect_util.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/side_effect_util.h" + +namespace xla { + +const char kXlaHostTransferRendezvousNameAttr[] = + "_xla_host_transfer_rendezvous"; + +const char kXlaHostTransferOriginalTypeAttr[] = + "_xla_host_transfer_original_type"; + +const char kXlaHostTransferIsLowerBitsAttr[] = + "_xla_host_transfer_is_lower_bits"; + +} // namespace xla diff --git a/tensorflow/compiler/xla/side_effect_util.h b/tensorflow/compiler/xla/side_effect_util.h new file mode 100644 index 00000000000..86cc0e6a06c --- /dev/null +++ b/tensorflow/compiler/xla/side_effect_util.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_ + +namespace xla { + +// XLA frontend attribute name which specifies TensorFlow rendezvous name. +extern const char kXlaHostTransferRendezvousNameAttr[]; + +// XLA frontend attribute name which specifies original host transfer type. +// Value is XLA primitive type in lower case. +extern const char kXlaHostTransferOriginalTypeAttr[]; + +// XLA frontend attribute name which specifies whether a host transfer +// instruction is lower bits for a splitted X64 host transfer. Value is "true" +// or "false". +extern const char kXlaHostTransferIsLowerBitsAttr[]; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/core/tpu/kernels/xla/BUILD b/tensorflow/core/tpu/kernels/xla/BUILD index 774365c31c5..34a8375c54d 100644 --- a/tensorflow/core/tpu/kernels/xla/BUILD +++ b/tensorflow/core/tpu/kernels/xla/BUILD @@ -31,6 +31,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:side_effect_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc index be3ee1c9d24..52bff5cd9de 100644 --- a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/side_effect_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -152,8 +153,9 @@ class HostComputeOp : public XlaOpKernel { input_shapes[i], &xla_shape)); // Specify frontend attributes. xla::FrontendAttributes attrs; - (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name; - (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] = + (*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = + channel_name; + (*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] = xla::primitive_util::LowercasePrimitiveTypeName( xla_shape.element_type()); b->SetFrontendAttributes(attrs); @@ -211,8 +213,9 @@ class HostComputeOp : public XlaOpKernel { const string channel_name = absl::StrCat(key_, "_htod_", i); // Specify frontend attributes. xla::FrontendAttributes attrs; - (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name; - (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] = + (*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = + channel_name; + (*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] = xla::primitive_util::LowercasePrimitiveTypeName( xla_output_shapes->at(i).element_type()); b->SetFrontendAttributes(attrs); @@ -416,8 +419,8 @@ class SendToHostOp : public XlaOpKernel { &xla_shape)); // Specify frontend attributes. xla::FrontendAttributes attrs; - (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_; - (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] = + (*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_; + (*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] = xla::primitive_util::LowercasePrimitiveTypeName( xla_shape.element_type()); b->SetFrontendAttributes(attrs); @@ -468,8 +471,8 @@ class RecvFromHostOp : public XlaOpKernel { ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape)); // Specify frontend attributes. xla::FrontendAttributes attrs; - (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_; - (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] = + (*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_; + (*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] = xla::primitive_util::LowercasePrimitiveTypeName( xla_shape.element_type()); b->SetFrontendAttributes(attrs);