Add function to preprocess TF graph before encapsulating XLA computations.
PiperOrigin-RevId: 217571411
This commit is contained in:
parent
d48968cc90
commit
b021a8b041
tensorflow/compiler
@ -411,7 +411,11 @@ cc_library(
|
||||
hdrs = ["encapsulate_util.h"],
|
||||
deps = [
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
|
@ -20,6 +20,10 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -36,10 +40,319 @@ absl::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a value to the node's list attribute.
|
||||
template <typename T>
|
||||
Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
|
||||
std::vector<T> attr_value;
|
||||
Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
|
||||
if (!s.ok() && s.code() != error::NOT_FOUND) {
|
||||
return s;
|
||||
}
|
||||
|
||||
n->ClearAttr(attr_name);
|
||||
attr_value.push_back(value);
|
||||
n->AddAttr(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces attribute value.
|
||||
template <typename T>
|
||||
void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
|
||||
n->ClearAttr(attr_name);
|
||||
n->AddAttr(attr_name, value);
|
||||
}
|
||||
|
||||
// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of
|
||||
// PreprocessForEncapsulation() for details.
|
||||
Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name) {
|
||||
// Gather edges to remove. We should not remove the edge while iterating.
|
||||
std::vector<const Edge*> edges_to_remove;
|
||||
for (const Edge* e : g->edges()) {
|
||||
if (!e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto src_xla_computation =
|
||||
GetStringAttr(*e->src(), xla_computation_attr_name);
|
||||
auto dst_xla_computation =
|
||||
GetStringAttr(*e->dst(), xla_computation_attr_name);
|
||||
auto src_outside_compilation =
|
||||
GetStringAttr(*e->src(), outside_compilation_attr_name);
|
||||
auto dst_outside_compilation =
|
||||
GetStringAttr(*e->dst(), outside_compilation_attr_name);
|
||||
|
||||
if (!src_xla_computation && !dst_xla_computation) {
|
||||
continue;
|
||||
} else if (src_xla_computation && !dst_xla_computation) {
|
||||
if (src_outside_compilation) {
|
||||
// Case 1d: outside compilation to host computation control edge.
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
|
||||
}
|
||||
} else if (!src_xla_computation && dst_xla_computation) {
|
||||
if (dst_outside_compilation) {
|
||||
// Case 1d: host computation control to outside compilation edge.
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
|
||||
}
|
||||
} else { // src_xla_computation && dst_xla_computation
|
||||
if (*src_xla_computation != *dst_xla_computation) {
|
||||
if (src_outside_compilation && dst_outside_compilation) {
|
||||
// Case 1c: outside compilation to outside compilation control edge.
|
||||
edges_to_remove.push_back(e);
|
||||
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
|
||||
} else if (src_outside_compilation && !dst_outside_compilation) {
|
||||
// Case 1b: outside compilation to another XLA computaition control
|
||||
// edge.
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->src(), kXlaConnectedToOtherXlaComputationAttrName,
|
||||
*dst_xla_computation));
|
||||
} else if (!src_outside_compilation && dst_outside_compilation) {
|
||||
// Case 1b: another XLA computaition to outside compilation control
|
||||
// edge.
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
|
||||
*src_xla_computation));
|
||||
}
|
||||
} else { // *src_xla_computation == *dst_xla_computation
|
||||
if (src_outside_compilation && dst_outside_compilation) {
|
||||
if (*src_outside_compilation != *dst_outside_compilation) {
|
||||
// Case 1c: outside compilation to outside compilation control edge.
|
||||
edges_to_remove.push_back(e);
|
||||
|
||||
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
|
||||
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
|
||||
}
|
||||
} else if (src_outside_compilation && !dst_outside_compilation) {
|
||||
// Case 1a: outside compilation to its XLA computation control edge.
|
||||
ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
|
||||
} else if (!src_outside_compilation && dst_outside_compilation) {
|
||||
// Case 1a: XLA computation to outside compilation in it control edge.
|
||||
ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto e : edges_to_remove) {
|
||||
g->RemoveEdge(e);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Step 2 for PreprocessForEncapsulation(). See comments of
|
||||
// PreprocessForEncapsulation() for details.
|
||||
Status ProcessXlaToXlaDataEdges(Graph* g,
|
||||
const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name) {
|
||||
// Gather edges between XLA computations. Notice that we do not store `Edge*`
|
||||
// directly because we remove some nodes while adding Identity nodes, and
|
||||
// those Edge pointers might be invalidated.
|
||||
struct EdgeInfo {
|
||||
int dst_input, dst_node_id;
|
||||
};
|
||||
std::vector<EdgeInfo> edges;
|
||||
for (const Edge* e : g->edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto src_xla_computation =
|
||||
GetStringAttr(*e->src(), xla_computation_attr_name);
|
||||
auto dst_xla_computation =
|
||||
GetStringAttr(*e->dst(), xla_computation_attr_name);
|
||||
auto src_outside_compilation =
|
||||
GetStringAttr(*e->src(), outside_compilation_attr_name);
|
||||
auto dst_outside_compilation =
|
||||
GetStringAttr(*e->dst(), outside_compilation_attr_name);
|
||||
if (!src_xla_computation || !dst_xla_computation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (*src_xla_computation != *dst_xla_computation) {
|
||||
if (src_outside_compilation || dst_outside_compilation) {
|
||||
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
|
||||
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
|
||||
}
|
||||
} else { // *src_xla_computation == *dst_xla_computation
|
||||
if (src_outside_compilation && dst_outside_compilation &&
|
||||
*src_outside_compilation != *dst_outside_compilation) {
|
||||
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
|
||||
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each XLA -> XLA edge, add an Identity node between src and dst.
|
||||
for (int i = 0; i < edges.size(); i++) {
|
||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
|
||||
Node* src = e->src();
|
||||
int src_output = e->src_output(), dst_input = e->dst_input();
|
||||
g->RemoveEdge(e);
|
||||
|
||||
// Create Identity node, and connect it between `src` and `dst`.
|
||||
string identity_node_name =
|
||||
absl::StrCat("bridge_", src->name(), "_", dst->name());
|
||||
DataType dtype = src->output_type(src_output);
|
||||
TF_ASSIGN_OR_RETURN(Node * identity_node,
|
||||
BuildIdentityNode(g, identity_node_name, dtype, src,
|
||||
/*requested_device=*/absl::nullopt));
|
||||
identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name());
|
||||
g->AddEdge(src, src_output, identity_node, 0);
|
||||
g->AddEdge(identity_node, 0, dst, dst_input);
|
||||
|
||||
// Replace `e->dst()` because its input node changed.
|
||||
NodeDef new_def = dst->def();
|
||||
*new_def.mutable_input(dst_input) = identity_node->name();
|
||||
TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
|
||||
|
||||
// Other edge in `edges` might have `e->dst()` as src or dst
|
||||
// node. Before removing `e->dst()`, replace those edges with corresponding
|
||||
// edges for `dst_replace_node`.
|
||||
for (int j = i + 1; j < edges.size(); j++) {
|
||||
if (edges[j].dst_node_id == edges[i].dst_node_id) {
|
||||
edges[j].dst_node_id = dst_replace_node->id();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Step 3 for PreprocessForEncapsulation(). See comments of
|
||||
// PreprocessForEncapsulation() for details.
|
||||
Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
||||
Graph* g, const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name) {
|
||||
// Gather edges between outside compilation and host computation. Notice that
|
||||
// we do not store `Edge*` directly because we remove some nodes while adding
|
||||
// Identity nodes, and those Edge pointers might be invalidated.
|
||||
struct EdgeInfo {
|
||||
int dst_input, dst_node_id;
|
||||
bool is_host_to_outside_compilation;
|
||||
};
|
||||
std::vector<EdgeInfo> edges;
|
||||
for (const Edge* e : g->edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr &&
|
||||
e->dst()->attrs().Find(xla_computation_attr_name) != nullptr &&
|
||||
e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) {
|
||||
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
|
||||
/*is_host_to_outside_compilation=*/true});
|
||||
VLOG(4) << "Host -> oc edge: " << e->DebugString();
|
||||
} else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr &&
|
||||
e->src()->attrs().Find(xla_computation_attr_name) != nullptr &&
|
||||
e->src()->attrs().Find(outside_compilation_attr_name) !=
|
||||
nullptr) {
|
||||
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
|
||||
/*is_host_to_outside_compilation=*/false});
|
||||
VLOG(4) << "Oc -> host edge: " << e->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||
// outside compilation node input.
|
||||
std::map<string, Node*> placeholders;
|
||||
for (int i = 0; i < edges.size(); i++) {
|
||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
|
||||
Node* src = e->src();
|
||||
int src_output = e->src_output(), dst_input = e->dst_input();
|
||||
g->RemoveEdge(e);
|
||||
|
||||
// Find or create placeholder node.
|
||||
string new_name =
|
||||
edges[i].is_host_to_outside_compilation
|
||||
? absl::StrCat(src->name(), "_host_to_oc_placeholder")
|
||||
: absl::StrCat(src->name(), "_oc_to_host_placeholder");
|
||||
auto iter = placeholders.find(new_name);
|
||||
Node* placeholder_node;
|
||||
if (iter == placeholders.end()) {
|
||||
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
||||
placeholder_builder.Attr("dtype", src->output_type(src_output));
|
||||
if (edges[i].is_host_to_outside_compilation) {
|
||||
placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName,
|
||||
src->name());
|
||||
placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName,
|
||||
src_output);
|
||||
// If this placeholder node is in outside compilation, we need to set
|
||||
// `xla_computation_attr_name` and `outside_compilation_attr_name`.
|
||||
string xla_computation_attr, outside_compilation_attr;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name,
|
||||
&xla_computation_attr));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
|
||||
outside_compilation_attr_name,
|
||||
&outside_compilation_attr));
|
||||
placeholder_builder.Attr(xla_computation_attr_name,
|
||||
xla_computation_attr);
|
||||
placeholder_builder.Attr(outside_compilation_attr_name,
|
||||
outside_compilation_attr);
|
||||
} else {
|
||||
placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName,
|
||||
src->name());
|
||||
placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName,
|
||||
src_output);
|
||||
}
|
||||
NodeDef placeholder_def;
|
||||
TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
|
||||
Status s;
|
||||
placeholder_node = g->AddNode(placeholder_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
placeholders[new_name] = placeholder_node;
|
||||
} else {
|
||||
placeholder_node = iter->second;
|
||||
}
|
||||
g->AddEdge(placeholder_node, 0, dst, dst_input);
|
||||
g->RemoveEdge(e);
|
||||
|
||||
// Replace `e->dst()` because its input node changed.
|
||||
NodeDef new_def = dst->def();
|
||||
*new_def.mutable_input(dst_input) = placeholder_node->name();
|
||||
TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
|
||||
|
||||
// Other edge in `edges` might have `e->dst()` as src or dst
|
||||
// node. Before removing `e->dst()`, replace those edges with corresponding
|
||||
// edges for `dst_replace_node`.
|
||||
for (int j = i + 1; j < edges.size(); j++) {
|
||||
if (edges[j].dst_node_id == edges[i].dst_node_id) {
|
||||
edges[j].dst_node_id = dst_replace_node->id();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
|
||||
|
||||
const char kXlaConnectedToXlaComputationAttrName[] =
|
||||
"_xla_connected_to_xla_computation";
|
||||
const char kXlaConnectedFromXlaComputationAttrName[] =
|
||||
"_xla_connected_from_xla_computation";
|
||||
const char kXlaConnectedToOtherXlaComputationAttrName[] =
|
||||
"_xla_connected_to_other_xla_computation";
|
||||
const char kXlaConnectedFromOtherXlaComputationAttrName[] =
|
||||
"_xla_connected_from_other_xla_computation";
|
||||
const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies";
|
||||
const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src";
|
||||
const char kOutsideCompilationToHostOriginalNodeAttrName[] =
|
||||
"_xla_oc_to_host_node_name";
|
||||
const char kOutsideCompilationToHostSrcOutputAttrName[] =
|
||||
"_xla_oc_to_host_src_output";
|
||||
const char kHostToOutsideCompilationOriginalNodeAttrName[] =
|
||||
"_xla_host_to_oc_node_name";
|
||||
const char kHostToOutsideCompilationSrcOutputAttrName[] =
|
||||
"_xla_host_to_oc_src_output";
|
||||
|
||||
Status PerformStaticShapeInferenceBeforeEncapsulation(
|
||||
Graph* g, const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name) {
|
||||
@ -91,4 +404,16 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PreprocessForEncapsulation(Graph* g,
|
||||
const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name) {
|
||||
TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name,
|
||||
outside_compilation_attr_name));
|
||||
TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name,
|
||||
outside_compilation_attr_name));
|
||||
TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
||||
g, xla_computation_attr_name, outside_compilation_attr_name));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -44,6 +44,77 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
|
||||
Graph* g, const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name);
|
||||
|
||||
// Attribute indicating that some ops in this node's XLA computation has control
|
||||
// dependency on this node. Attribute value will always be "true".
|
||||
extern const char kXlaConnectedToXlaComputationAttrName[];
|
||||
|
||||
// Attribute indicating that this node has control dependency on some ops in
|
||||
// this node's XLA computation. Attribute value will always be "true".
|
||||
extern const char kXlaConnectedFromXlaComputationAttrName[];
|
||||
|
||||
// Attribute indicating that some ops in other XLA computation has control
|
||||
// dependency on this node. Attribute value will be a list of string (XLA
|
||||
// computation names).
|
||||
extern const char kXlaConnectedToOtherXlaComputationAttrName[];
|
||||
|
||||
// Attribute indicating that this node has control dependency on some ops in
|
||||
// other XLA computation. Attribute value will be a list of string (XLA
|
||||
// computation names).
|
||||
extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
|
||||
|
||||
// Attribute indicating that this node has control dependencies on some other
|
||||
// nodes. Attribute value will be a list of string (node names).
|
||||
extern const char kXlaControlDependenciesAttrName[];
|
||||
|
||||
// Attribute indicating that this is an Identity node added to act as a bridge
|
||||
// between different XLA computations. Attribute value will be string (source
|
||||
// node name).
|
||||
extern const char kBridgeSourceNodeAttrName[];
|
||||
|
||||
// Attribute indicating that this is an Placeholder node added to act as a
|
||||
// temporary input node for an outside compilation node. Attribute value will be
|
||||
// string (original input node name).
|
||||
extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
|
||||
|
||||
// Attribute indicating that this is an Placeholder node added to act as a
|
||||
// temporary input node for an outside compilation node. Attribute value will be
|
||||
// int (src_output for original edge).
|
||||
extern const char kOutsideCompilationToHostSrcOutputAttrName[];
|
||||
|
||||
// Attribute indicating that this is an Placeholder node added to act as a
|
||||
// temporary input node for an host node. Attribute value will be string
|
||||
// (original input node name).
|
||||
extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
|
||||
|
||||
// Attribute indicating that this is an Placeholder node added to act as a
|
||||
// temporary input node for a host node. Attribute value will be int (src_output
|
||||
// for original edge).
|
||||
extern const char kHostToOutsideCompilationSrcOutputAttrName[];
|
||||
|
||||
// Preprocesses the graph for encapsulation. It will perform the following
|
||||
// operations in order:
|
||||
//
|
||||
// 1a. For control edges between outside compilation and its XLA computation,
|
||||
// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
|
||||
// outside compilation node.
|
||||
// 1b. For control edges between outside compilation and another XLA
|
||||
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
|
||||
// = XLA computation node name" to the outside compilation node.
|
||||
// 1c. For control edges between different outside compilations, remove the edge
|
||||
// and add attr "kXlaControlDependenciesAttrName = src node name" to dst
|
||||
// node.
|
||||
// 1d. For control edges between outside compilation and host computation,
|
||||
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
|
||||
// name" to dst node.
|
||||
// 2. For data edges between different XLA computations, if either src or dst
|
||||
// is outside compilation, add an Identity node in between the edge. The
|
||||
// identity node will have attr kBridgeSourceNodeAttrName.
|
||||
// 3. For data edges between outside compilation and host computation, remove
|
||||
// the edge and create a Placeholder node as dst node's input.
|
||||
Status PreprocessForEncapsulation(Graph* g,
|
||||
const string& xla_computation_attr_name,
|
||||
const string& outside_compilation_attr_name);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
|
||||
|
@ -47,8 +47,8 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
|
||||
PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
|
||||
|
||||
// Check that only "add" node now has _xla_inferred_shapes attr.
|
||||
std::vector<Node*> nodes_with_inferred_shape;
|
||||
for (Node* n : g.nodes()) {
|
||||
std::vector<Node *> nodes_with_inferred_shape;
|
||||
for (Node *n : g.nodes()) {
|
||||
if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
|
||||
nodes_with_inferred_shape.push_back(n);
|
||||
}
|
||||
@ -65,4 +65,175 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
|
||||
EXPECT_EQ(shape_proto.dim(0).size(), 2);
|
||||
}
|
||||
|
||||
TEST(PreprocessForEncapsulationTest, ControlEdges) {
|
||||
// Build the graph:
|
||||
// "const_0" and "const_1" in host computation
|
||||
// "add" = "const_0" + "const_1" in XLA computation 0
|
||||
// "identity0" = "add" in XLA computation 0 & outside compilation 0
|
||||
// "identity1" = "identity0" in XLA computation 0
|
||||
// "identity2" = "identity1" in host computation
|
||||
// "identity3" = "identity2" in XLA computation 1
|
||||
// "identity4" = "identity3" in XLA computation 1 & outside compilation 1
|
||||
// "identity5" = "identity4" in XLA computation 1
|
||||
// "identity6" = "identity5" in host computation
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
|
||||
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
|
||||
Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
|
||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
|
||||
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
|
||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
|
||||
Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
|
||||
Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
|
||||
Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_CHECK_OK(s.ToGraph(&g));
|
||||
auto node_index = g.BuildNodeNameIndex();
|
||||
|
||||
// Set XLA computation/outside compilation attr, and add control edges.
|
||||
Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
|
||||
*identity0_node = node_index["identity0"],
|
||||
*identity1_node = node_index["identity1"],
|
||||
*identity2_node = node_index["identity2"],
|
||||
*identity3_node = node_index["identity3"],
|
||||
*identity4_node = node_index["identity4"],
|
||||
*identity5_node = node_index["identity5"];
|
||||
add_node->AddAttr("_xla", "0");
|
||||
identity0_node->AddAttr("_xla", "0");
|
||||
identity0_node->AddAttr("_oc", "0");
|
||||
identity1_node->AddAttr("_xla", "0");
|
||||
identity3_node->AddAttr("_xla", "1");
|
||||
identity4_node->AddAttr("_xla", "1");
|
||||
identity4_node->AddAttr("_oc", "0");
|
||||
identity5_node->AddAttr("_xla", "1");
|
||||
// Case 1a: control edges between outside compilation and its XLA computation.
|
||||
g.AddControlEdge(add_node, identity0_node);
|
||||
g.AddControlEdge(identity0_node, identity1_node);
|
||||
// Case 1b: control edges between outside compilation and another XLA
|
||||
// computation.
|
||||
g.AddControlEdge(identity0_node, identity3_node);
|
||||
g.AddControlEdge(identity1_node, identity4_node);
|
||||
// Case 1c: control edges between different outside compilations.
|
||||
g.AddControlEdge(identity0_node, identity4_node);
|
||||
// Case 1d: control edges between outside compilation and host computation.
|
||||
g.AddControlEdge(const0_node, identity0_node);
|
||||
g.AddControlEdge(identity0_node, identity2_node);
|
||||
|
||||
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
|
||||
|
||||
// Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the
|
||||
// outside compilation node.
|
||||
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
|
||||
kXlaConnectedFromXlaComputationAttrName));
|
||||
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
|
||||
kXlaConnectedToXlaComputationAttrName));
|
||||
// Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
|
||||
// to the outside compilation node.
|
||||
std::vector<string> attr;
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
|
||||
kXlaConnectedToOtherXlaComputationAttrName, &attr));
|
||||
EXPECT_EQ(attr.size(), 1);
|
||||
EXPECT_EQ(attr[0], "1");
|
||||
attr.clear();
|
||||
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
|
||||
kXlaConnectedFromOtherXlaComputationAttrName, &attr));
|
||||
EXPECT_EQ(attr.size(), 1);
|
||||
EXPECT_EQ(attr[0], "0");
|
||||
// Case 1c: add attr "_xla_control_deps = src node name" to dst node.
|
||||
attr.clear();
|
||||
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
|
||||
kXlaControlDependenciesAttrName, &attr));
|
||||
EXPECT_EQ(attr.size(), 1);
|
||||
EXPECT_EQ(attr[0], "identity0");
|
||||
// Case 1d: add attr "_xla_control_deps = src node name" to dst node.
|
||||
attr.clear();
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
|
||||
kXlaControlDependenciesAttrName, &attr));
|
||||
EXPECT_EQ(attr.size(), 1);
|
||||
EXPECT_EQ(attr[0], "const_0");
|
||||
attr.clear();
|
||||
TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
|
||||
kXlaControlDependenciesAttrName, &attr));
|
||||
EXPECT_EQ(attr.size(), 1);
|
||||
EXPECT_EQ(attr[0], "identity0");
|
||||
}
|
||||
|
||||
TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
// Build the graph:
|
||||
// "const_0" and "const_1" in host computation
|
||||
// "add0" = "const_0" + "const_1" in XLA computation 0
|
||||
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
|
||||
// "identity0" = "add1" in XLA computation 0
|
||||
// "add2" = "add1" + "identity0" in host computation
|
||||
// "add3" = "add1" + "add2" in XLA computation 1
|
||||
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
|
||||
// "identity1" = "add4" in XLA computation 1
|
||||
// "identity2" = "identity1" in host computation
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
|
||||
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
|
||||
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
|
||||
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
|
||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
|
||||
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
|
||||
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
|
||||
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
|
||||
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
|
||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_CHECK_OK(s.ToGraph(&g));
|
||||
auto node_index = g.BuildNodeNameIndex();
|
||||
|
||||
// Set XLA computation/outside compilation attr.
|
||||
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
|
||||
*identity0_node = node_index["identity0"],
|
||||
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
|
||||
*identity1_node = node_index["identity1"];
|
||||
add0_node->AddAttr("_xla", "0");
|
||||
add1_node->AddAttr("_xla", "0");
|
||||
add1_node->AddAttr("_oc", "0");
|
||||
identity0_node->AddAttr("_xla", "0");
|
||||
add3_node->AddAttr("_xla", "1");
|
||||
add4_node->AddAttr("_xla", "1");
|
||||
add4_node->AddAttr("_oc", "0");
|
||||
identity1_node->AddAttr("_xla", "1");
|
||||
|
||||
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
|
||||
|
||||
// Check input nodes for related data edges.
|
||||
node_index = g.BuildNodeNameIndex();
|
||||
// Step 2: add an Identity node between different XLA computations.
|
||||
Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
|
||||
EXPECT_NE(bridge_add1_add3, nullptr);
|
||||
string str;
|
||||
TF_CHECK_OK(
|
||||
GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
|
||||
EXPECT_EQ(str, "add1");
|
||||
Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
|
||||
EXPECT_NE(bridge_identity0_add4, nullptr);
|
||||
// Step 3: add placeholder for edges between host computation and outside
|
||||
// compilation.
|
||||
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
|
||||
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
|
||||
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
|
||||
kOutsideCompilationToHostOriginalNodeAttrName, &str));
|
||||
EXPECT_EQ(str, "add1");
|
||||
int i;
|
||||
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
|
||||
kOutsideCompilationToHostSrcOutputAttrName, &i));
|
||||
EXPECT_EQ(i, 0);
|
||||
add4_node = node_index["add4"];
|
||||
ASSERT_NE(add4_node, nullptr);
|
||||
EXPECT_EQ(add4_node->def().input(0),
|
||||
"bridge_identity0_add4_host_to_oc_placeholder");
|
||||
Node *identity0_host_to_oc_placeholder =
|
||||
node_index["bridge_identity0_add4_host_to_oc_placeholder"];
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||
kHostToOutsideCompilationOriginalNodeAttrName, &str));
|
||||
EXPECT_EQ(str, "bridge_identity0_add4");
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||
kHostToOutsideCompilationSrcOutputAttrName, &i));
|
||||
EXPECT_EQ(i, 0);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -465,4 +464,60 @@ Status CachedFunctionHandles::ReleaseAllHandles() {
|
||||
return result;
|
||||
}
|
||||
|
||||
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
|
||||
// Create the replacement node.
|
||||
Status s;
|
||||
Node* new_node = g->AddNode(node_def, &s);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
|
||||
// Record original node's output edges and remove them first. This is to avoid
|
||||
// multiple producers for dst nodes' input.
|
||||
std::vector<OutEdgeInfo> out_edge_info;
|
||||
std::vector<const Edge*> out_edges;
|
||||
for (const Edge* edge : n->out_edges()) {
|
||||
out_edges.push_back(edge);
|
||||
out_edge_info.push_back(
|
||||
{edge->dst(), edge->src_output(), edge->dst_input()});
|
||||
}
|
||||
for (const Edge* edge : out_edges) {
|
||||
g->RemoveEdge(edge);
|
||||
}
|
||||
|
||||
// Add original node's input and output edges to the replacement node.
|
||||
for (const Edge* in_edge : n->in_edges()) {
|
||||
g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
|
||||
in_edge->dst_input());
|
||||
}
|
||||
for (const OutEdgeInfo& out_edge : out_edge_info) {
|
||||
g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
|
||||
}
|
||||
|
||||
// Remove the original node.
|
||||
g->RemoveNode(n);
|
||||
|
||||
return new_node;
|
||||
}
|
||||
|
||||
xla::StatusOr<Node*> BuildIdentityNode(
|
||||
Graph* graph, const string& node_name, DataType dtype, const Node* input,
|
||||
absl::optional<string> requested_device) {
|
||||
// Create identity node.
|
||||
NodeDef ndef;
|
||||
ndef.set_name(node_name);
|
||||
ndef.set_op("Identity");
|
||||
if (input) {
|
||||
ndef.add_input(input->name());
|
||||
}
|
||||
if (requested_device) {
|
||||
ndef.set_device(*requested_device);
|
||||
}
|
||||
AddNodeAttr("T", dtype, &ndef);
|
||||
Status s;
|
||||
Node* id_node = graph->AddNode(ndef, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
return id_node;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -168,6 +169,20 @@ class CachedFunctionHandles {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
|
||||
};
|
||||
|
||||
// Struct for node's output edge info.
|
||||
struct OutEdgeInfo {
|
||||
Node* dst;
|
||||
int src_output, dst_input;
|
||||
};
|
||||
|
||||
// Replaces node `n` with a new node whose NodeDef is `node_def`.
|
||||
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);
|
||||
|
||||
// Helper function that builds an Identity node.
|
||||
xla::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
|
||||
DataType dtype, const Node* input,
|
||||
absl::optional<string> requested_device);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
Loading…
Reference in New Issue
Block a user