[XLA] Break dependencies of XLA on TF.

PiperOrigin-RevId: 352649426
Change-Id: Ia65d1ff9e4df69467d5ad7264bb052c54a661c0a
This commit is contained in:
Peter Hawkins 2021-01-19 14:02:19 -08:00 committed by TensorFlower Gardener
parent daa61db597
commit b885741721
8 changed files with 83 additions and 29 deletions

View File

@ -34,15 +34,6 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
const char kXlaOriginalOutsideCompilationNodeName[] =
"_xla_original_oc_node_name";
const char kXlaHostTransferRendezvousNameAttr[] =
"_xla_host_transfer_rendezvous";
const char kXlaHostTransferOriginalTypeAttr[] =
"_xla_host_transfer_original_type";
const char kXlaHostTransferIsLowerBitsAttr[] =
"_xla_host_transfer_is_lower_bits";
Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
return errors::InvalidArgument("Node ", node->DebugString(),

View File

@ -64,18 +64,6 @@ bool HasSideEffectingNodes(const Graph& g);
Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
std::map<string, int>* host_compute_core);
// XLA frontend attribute name which specifies TensorFlow rendezvous name.
extern const char kXlaHostTransferRendezvousNameAttr[];
// XLA frontend attribute name which specifies original host transfer type.
// Value is XLA primitive type in lower case.
extern const char kXlaHostTransferOriginalTypeAttr[];
// XLA frontend attribute name which specifies whether a host transfer
// instruction is lower bits for a splitted X64 host transfer. Value is "true"
// or "false".
extern const char kXlaHostTransferIsLowerBitsAttr[];
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_

View File

@ -964,6 +964,12 @@ cc_library(
hdrs = ["union_find.h"],
)
# Constants naming XLA side-effect (host-transfer) frontend attributes.
# Intentionally has no deps (srcs/hdrs only), so XLA targets can use these
# names without pulling in TensorFlow — see the commit intent: "Break
# dependencies of XLA on TF".
cc_library(
name = "side_effect_util",
srcs = ["side_effect_util.cc"],
hdrs = ["side_effect_util.h"],
)
# -----------------------------------------------------------------------------
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.

View File

@ -209,6 +209,7 @@ class Layout {
return *this;
}
static constexpr int64 kDefaultMemorySpace = 0;
static constexpr int64 kGenericFastMemorySpace = 1;
int64 memory_space() const { return memory_space_; }
Layout& set_memory_space(int64 value) {
memory_space_ = value;

View File

@ -0,0 +1,29 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/side_effect_util.h"
namespace xla {
// Definitions for the host-transfer frontend attribute names declared in
// side_effect_util.h. The string values are wire/attribute keys and must not
// change, since consumers (e.g. TF2XLA bridge code) match on them by value.
// Frontend attribute key carrying the TensorFlow rendezvous name.
const char kXlaHostTransferRendezvousNameAttr[] =
"_xla_host_transfer_rendezvous";
// Frontend attribute key carrying the original host-transfer type
// (an XLA primitive type name in lower case).
const char kXlaHostTransferOriginalTypeAttr[] =
"_xla_host_transfer_original_type";
// Frontend attribute key marking the lower-bits half of a split X64
// host transfer ("true"/"false").
const char kXlaHostTransferIsLowerBitsAttr[] =
"_xla_host_transfer_is_lower_bits";
} // namespace xla

View File

@ -0,0 +1,35 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_
// Names of XLA frontend attributes used to describe side-effecting
// host-transfer instructions. Defined in side_effect_util.cc.
namespace xla {
// XLA frontend attribute name which specifies TensorFlow rendezvous name.
extern const char kXlaHostTransferRendezvousNameAttr[];
// XLA frontend attribute name which specifies original host transfer type.
// Value is XLA primitive type in lower case.
extern const char kXlaHostTransferOriginalTypeAttr[];
// XLA frontend attribute name which specifies whether a host transfer
// instruction is lower bits for a split X64 host transfer. Value is "true"
// or "false".
extern const char kXlaHostTransferIsLowerBitsAttr[];
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_

View File

@ -31,6 +31,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:side_effect_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/side_effect_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
@ -152,8 +153,9 @@ class HostComputeOp : public XlaOpKernel {
input_shapes[i], &xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] =
channel_name;
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);
@ -211,8 +213,9 @@ class HostComputeOp : public XlaOpKernel {
const string channel_name = absl::StrCat(key_, "_htod_", i);
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] =
channel_name;
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_output_shapes->at(i).element_type());
b->SetFrontendAttributes(attrs);
@ -416,8 +419,8 @@ class SendToHostOp : public XlaOpKernel {
&xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);
@ -468,8 +471,8 @@ class RecvFromHostOp : public XlaOpKernel {
ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);