[XLA] Break dependencies of XLA on TF.
PiperOrigin-RevId: 352649426 Change-Id: Ia65d1ff9e4df69467d5ad7264bb052c54a661c0a
This commit is contained in:
parent
daa61db597
commit
b885741721
@ -34,15 +34,6 @@ const char kXlaIsPlaceholderForTailOcAttrName[] =
|
||||
const char kXlaOriginalOutsideCompilationNodeName[] =
|
||||
"_xla_original_oc_node_name";
|
||||
|
||||
const char kXlaHostTransferRendezvousNameAttr[] =
|
||||
"_xla_host_transfer_rendezvous";
|
||||
|
||||
const char kXlaHostTransferOriginalTypeAttr[] =
|
||||
"_xla_host_transfer_original_type";
|
||||
|
||||
const char kXlaHostTransferIsLowerBitsAttr[] =
|
||||
"_xla_host_transfer_is_lower_bits";
|
||||
|
||||
Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
|
||||
if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
|
||||
return errors::InvalidArgument("Node ", node->DebugString(),
|
||||
|
||||
@ -64,18 +64,6 @@ bool HasSideEffectingNodes(const Graph& g);
|
||||
Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr,
|
||||
std::map<string, int>* host_compute_core);
|
||||
|
||||
// XLA frontend attribute name which specifies TensorFlow rendezvous name.
|
||||
extern const char kXlaHostTransferRendezvousNameAttr[];
|
||||
|
||||
// XLA frontend attribute name which specifies original host transfer type.
|
||||
// Value is XLA primitive type in lower case.
|
||||
extern const char kXlaHostTransferOriginalTypeAttr[];
|
||||
|
||||
// XLA frontend attribute name which specifies whether a host transfer
|
||||
// instruction is lower bits for a splitted X64 host transfer. Value is "true"
|
||||
// or "false".
|
||||
extern const char kXlaHostTransferIsLowerBitsAttr[];
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
|
||||
|
||||
@ -964,6 +964,12 @@ cc_library(
|
||||
hdrs = ["union_find.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "side_effect_util",
|
||||
srcs = ["side_effect_util.cc"],
|
||||
hdrs = ["side_effect_util.h"],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
||||
|
||||
@ -209,6 +209,7 @@ class Layout {
|
||||
return *this;
|
||||
}
|
||||
static constexpr int64 kDefaultMemorySpace = 0;
|
||||
static constexpr int64 kGenericFastMemorySpace = 1;
|
||||
int64 memory_space() const { return memory_space_; }
|
||||
Layout& set_memory_space(int64 value) {
|
||||
memory_space_ = value;
|
||||
|
||||
29
tensorflow/compiler/xla/side_effect_util.cc
Normal file
29
tensorflow/compiler/xla/side_effect_util.cc
Normal file
@ -0,0 +1,29 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/side_effect_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
const char kXlaHostTransferRendezvousNameAttr[] =
|
||||
"_xla_host_transfer_rendezvous";
|
||||
|
||||
const char kXlaHostTransferOriginalTypeAttr[] =
|
||||
"_xla_host_transfer_original_type";
|
||||
|
||||
const char kXlaHostTransferIsLowerBitsAttr[] =
|
||||
"_xla_host_transfer_is_lower_bits";
|
||||
|
||||
} // namespace xla
|
||||
35
tensorflow/compiler/xla/side_effect_util.h
Normal file
35
tensorflow/compiler/xla/side_effect_util.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_
|
||||
|
||||
namespace xla {
|
||||
|
||||
// XLA frontend attribute name which specifies TensorFlow rendezvous name.
|
||||
extern const char kXlaHostTransferRendezvousNameAttr[];
|
||||
|
||||
// XLA frontend attribute name which specifies original host transfer type.
|
||||
// Value is XLA primitive type in lower case.
|
||||
extern const char kXlaHostTransferOriginalTypeAttr[];
|
||||
|
||||
// XLA frontend attribute name which specifies whether a host transfer
|
||||
// instruction is lower bits for a splitted X64 host transfer. Value is "true"
|
||||
// or "false".
|
||||
extern const char kXlaHostTransferIsLowerBitsAttr[];
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SIDE_EFFECT_UTIL_H_
|
||||
@ -31,6 +31,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/lib:scatter",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:side_effect_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
|
||||
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/side_effect_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
@ -152,8 +153,9 @@ class HostComputeOp : public XlaOpKernel {
|
||||
input_shapes[i], &xla_shape));
|
||||
// Specify frontend attributes.
|
||||
xla::FrontendAttributes attrs;
|
||||
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
|
||||
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] =
|
||||
channel_name;
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(
|
||||
xla_shape.element_type());
|
||||
b->SetFrontendAttributes(attrs);
|
||||
@ -211,8 +213,9 @@ class HostComputeOp : public XlaOpKernel {
|
||||
const string channel_name = absl::StrCat(key_, "_htod_", i);
|
||||
// Specify frontend attributes.
|
||||
xla::FrontendAttributes attrs;
|
||||
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
|
||||
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] =
|
||||
channel_name;
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(
|
||||
xla_output_shapes->at(i).element_type());
|
||||
b->SetFrontendAttributes(attrs);
|
||||
@ -416,8 +419,8 @@ class SendToHostOp : public XlaOpKernel {
|
||||
&xla_shape));
|
||||
// Specify frontend attributes.
|
||||
xla::FrontendAttributes attrs;
|
||||
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
|
||||
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_;
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(
|
||||
xla_shape.element_type());
|
||||
b->SetFrontendAttributes(attrs);
|
||||
@ -468,8 +471,8 @@ class RecvFromHostOp : public XlaOpKernel {
|
||||
ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
|
||||
// Specify frontend attributes.
|
||||
xla::FrontendAttributes attrs;
|
||||
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
|
||||
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = key_;
|
||||
(*attrs.mutable_map())[xla::kXlaHostTransferOriginalTypeAttr] =
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(
|
||||
xla_shape.element_type());
|
||||
b->SetFrontendAttributes(attrs);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user