[TF2XLA] Add a utility function to check whether GraphDef can trigger XLA compilation

PiperOrigin-RevId: 329034140
Change-Id: I0ef45dfd5cbcaf9906ba9d76ed29fcc27406aa30
George Karpenkov 2020-08-28 16:57:03 -07:00 committed by TensorFlower Gardener
parent 8f54070dd1
commit be2870d0db
9 changed files with 202 additions and 37 deletions
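For orientation: the utility added by this change is bool CanTriggerXlaCompilation(const GraphDef&), declared in tensorflow/compiler/jit/compilability_check_util.h. A minimal usage sketch, not part of the diff (the wrapper name GraphUsesXla is illustrative):

#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/core/framework/graph.pb.h"

// Illustrative wrapper: returns true if executing graph_def could cause XLA
// compilation, i.e. some node in the graph or in its function library carries
// a JIT attribute or is one of the XlaOps enumerated in the change below.
bool GraphUsesXla(const tensorflow::GraphDef& graph_def) {
  return tensorflow::CanTriggerXlaCompilation(graph_def);
}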


@@ -922,6 +922,7 @@ tf_cc_test(
":xla_cpu_jit",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/tf2xla:test_util",


@@ -518,10 +518,15 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
}
}
// Returns `true` iff node has a given `attr` set to `true`. Returns `false`
// both for the missing attr, and the attr set to `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
const auto& it = node.attr().find(attr);
return it != node.attr().end() && it->second.b();
}
bool CanCreateXlaKernel(const NodeDef& node_def) {
// If kXlaMustCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
return it != node_def.attr().end() && it->second.b();
return HasBoolAttr(node_def, kXlaMustCompileAttr);
}
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
@@ -564,4 +569,58 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
return Status::OK();
}
static auto const ops_triggering_xla_compilation =
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
"XlaConv",
"XlaDequantize",
"XlaDot",
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",
"XlaGather",
"XlaIf",
"XlaKeyValueSort",
"XlaPad",
"XlaRecv",
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile"};
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
HasBoolAttr(node, kXlaMustCompileAttr) ||
HasBoolAttr(node, kXlaCompileAttr) ||
HasBoolAttr(node, kXlaScopeAttr) ||
HasBoolAttr(node, kXlaInternalScopeAttr) ||
ops_triggering_xla_compilation->count(node.op());
}
bool CanTriggerXlaCompilation(const GraphDef& graph) {
for (const FunctionDef& function : graph.library().function()) {
for (const NodeDef& node : function.node_def()) {
if (NodeCanTriggerXlaCompilation(node)) {
return true;
}
}
}
for (const NodeDef& node : graph.node()) {
if (NodeCanTriggerXlaCompilation(node)) {
return true;
}
}
return false;
}
} // namespace tensorflow
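For reference, NodeCanTriggerXlaCompilation treats a node as a trigger if it carries the cluster-id attribute, any of the boolean JIT attributes, or one of the XlaOps listed above. A minimal sketch (not part of the diff) of a GraphDef the new check would report as compilable, assuming kXlaMustCompileAttr is the constant declared in tensorflow/compiler/jit/defs.h:

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Sketch: a one-node GraphDef whose node sets the boolean attribute
// kXlaMustCompileAttr, one of the attributes NodeCanTriggerXlaCompilation
// inspects; CanTriggerXlaCompilation would return true for this graph.
tensorflow::GraphDef MakeXlaMarkedGraph() {
  tensorflow::GraphDef graph_def;
  tensorflow::NodeDef* node = graph_def.add_node();
  node->set_name("id");
  node->set_op("Identity");
  (*node->mutable_attr())[tensorflow::kXlaMustCompileAttr].set_b(true);
  return graph_def;
}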


@@ -283,6 +283,9 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
// Check whether graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_


@@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
"unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
FunctionDef identity_func = FunctionDefHelper::Create(
"IdentityFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
*library.add_function() = identity_func;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("IdentityFunc");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
}
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
FunctionDef sort_func = FunctionDefHelper::Create(
"SortFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
*library.add_function() = sort_func;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("SortFunc");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
AttrValue true_attribute;
true_attribute.set_b(true);
FunctionDef identity_func = FunctionDefHelper::Create(
"IdentityFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
(*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
FunctionDef call_identity = FunctionDefHelper::Create(
"CallIdentity",
/*in_def=*/{"x:float"},
/*out_def=*/{"z:float"}, /*attr_def=*/{},
/*node_def=*/
{{{"func_call"},
"PartitionedCall",
{"x"},
{{"Tin", DataTypeSlice({DT_FLOAT})},
{"Tout", DataTypeSlice({DT_FLOAT})},
{"f",
FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})},
{kXlaMustCompileAttr, true}}}},
/*ret_def=*/{{"z", "func_call:output:0"}});
*library.add_function() = identity_func;
*library.add_function() = call_identity;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("CallIdentity");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
} // namespace
} // namespace tensorflow


@@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope";
// only when auto_jit is ON.
const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
const char* const kXlaClusterIdAttr = "_xla_compile_id";
} // namespace tensorflow


@@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile"
extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
// The id of the compiled cluster.
extern const char* const kXlaClusterIdAttr; // "_xla_compile_id"
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_
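Because kXlaClusterIdAttr now lives in jit/defs.h, other passes can read the compiled-cluster id through the shared constant instead of the EncapsulateXlaComputationsPass::kXlaClusterAttr member it replaces. A hedged sketch of such a lookup, not part of the diff (the helper name GetClusterId is illustrative):

#include <string>

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

// Illustrative: fetch the "_xla_compile_id" value from a node, mirroring the
// GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name) pattern used in the pass
// below; returns false if the node is not part of a compiled cluster.
bool GetClusterId(const tensorflow::Node* n, std::string* cluster_id) {
  return tensorflow::GetNodeAttr(n->attrs(), tensorflow::kXlaClusterIdAttr,
                                 cluster_id)
      .ok();
}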


@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -34,9 +35,6 @@ limitations under the License.
namespace tensorflow {
const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
"_xla_compile_id";
namespace {
const char* const kXlaClusterOutput = "XlaClusterOutput";
@@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) {
for (Node* n : graph->nodes()) {
string name;
// Only consider nodes being compiled.
if (!GetNodeAttr(n->attrs(),
EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
.ok())
continue;
if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
// Early return for any node with a device that is not a CPU or GPU.
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
@@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
retvals[i]->AddAttr("index", i);
}
AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
call_def);
AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def);
AddNodeAttr("_variable_start_index", variable_start_index, call_def);
// Uniquify the function name.
@@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// O(n) pass over the edges.
for (const Edge* e : (*graph)->edges()) {
if (!e->IsControlEdge() &&
e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr &&
e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
"Undeclared output of XLA computation. Some common causes of this "
@@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
auto output = absl::make_unique<Graph>((*graph)->op_registry());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true,
&output, flib_def),
EncapsulateSubgraphsInFunctions(
kXlaClusterIdAttr, **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true, &output, flib_def),
"EncapsulateXlaComputationsPass failed");
graph->swap(output);
return Status::OK();
@@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// while iterating.
std::vector<Node*> launch_nodes;
for (Node* n : graph->nodes()) {
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr);
if (!name.empty()) {
launch_nodes.push_back(n);
}


@@ -34,8 +34,6 @@ namespace tensorflow {
// XlaLaunch operators.
class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
public:
static const char* const kXlaClusterAttr; // _xla_compile_id
Status Run(const GraphOptimizationPassOptions& options) override;
// The following methods are public only for unit tests.


@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
@@ -46,19 +47,18 @@ static std::unique_ptr<Graph> MakeOuterGraph(
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
NodeDef def;
TF_CHECK_OK(
NodeDefBuilder("launch0", function, &flib_def)
.Input(a.node()->name(), 0, DT_INT32)
.Input(b.node()->name(), 0, DT_FLOAT)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Device("/gpu:0")
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
.Input(a.node()->name(), 0, DT_INT32)
.Input(b.node()->name(), 0, DT_FLOAT)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Device("/gpu:0")
.Attr(kXlaClusterIdAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
Status status;
Node* launch = scope.graph()->AddNode(def, &status);
@@ -107,7 +107,7 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
node->set_requested_device("/gpu:0");
};
@@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
: ops::Add(scope.WithOpName("E"), a1, a0);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
"launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
};
add_attrs(e.node());
@@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
node->set_requested_device("/gpu:0");
};