[TF2XLA] Add a utility function to check whether GraphDef can trigger XLA compilation
PiperOrigin-RevId: 329034140 Change-Id: I0ef45dfd5cbcaf9906ba9d76ed29fcc27406aa30
This commit is contained in:
parent
8f54070dd1
commit
be2870d0db
@ -922,6 +922,7 @@ tf_cc_test(
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/compiler/tf2xla:test_util",
|
||||
|
@ -518,10 +518,15 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
}
|
||||
}
|
||||
|
||||
// Returns `true` iff node has a given `attr` set to `true`. Returns `false`
|
||||
// both for the missing attr, and the attr set to `false`.
|
||||
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
|
||||
const auto& it = node.attr().find(attr);
|
||||
return it != node.attr().end() && it->second.b();
|
||||
}
|
||||
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||
return it != node_def.attr().end() && it->second.b();
|
||||
return HasBoolAttr(node_def, kXlaMustCompileAttr);
|
||||
}
|
||||
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
@ -564,4 +569,58 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Names of ops whose mere presence in a graph is treated as able to trigger
// XLA compilation (see NodeCanTriggerXlaCompilation). The set is
// heap-allocated and not freed anywhere in the visible code; presumably this
// is deliberate so the object is never destroyed at program exit — confirm.
static absl::flat_hash_set<std::string>* const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{
        "XlaBroadcastHelper",
        "XlaConv",
        "XlaDequantize",
        "XlaDot",
        "XlaDynamicSlice",
        "XlaDynamicUpdateSlice",
        "XlaEinsum",
        "XlaGather",
        "XlaIf",
        "XlaKeyValueSort",
        "XlaPad",
        "XlaRecv",
        "XlaReduce",
        "XlaReduceWindow",
        "XlaReplicaId",
        "XlaScatter",
        "XlaSelectAndScatter",
        "XlaSelfAdjointEig",
        "XlaSend",
        "XlaSharding",
        "XlaSort",
        "XlaSpmdFullToShardShape",
        "XlaSpmdShardToFullShape",
        "XlaSvd",
        "XlaWhile",
    };
|
||||
|
||||
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
|
||||
return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
|
||||
HasBoolAttr(node, kXlaMustCompileAttr) ||
|
||||
HasBoolAttr(node, kXlaCompileAttr) ||
|
||||
HasBoolAttr(node, kXlaScopeAttr) ||
|
||||
HasBoolAttr(node, kXlaInternalScopeAttr) ||
|
||||
ops_triggering_xla_compilation->count(node.op());
|
||||
}
|
||||
|
||||
bool CanTriggerXlaCompilation(const GraphDef& graph) {
|
||||
for (const FunctionDef& function : graph.library().function()) {
|
||||
for (const NodeDef& node : function.node_def()) {
|
||||
if (NodeCanTriggerXlaCompilation(node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
if (NodeCanTriggerXlaCompilation(node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -283,6 +283,9 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
// set.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
// Returns true if any node in `graph`, including nodes inside functions of its
// function library, can trigger XLA compilation.
|
||||
bool CanTriggerXlaCompilation(const GraphDef& graph);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
"unsupported op"));
|
||||
}
|
||||
|
||||
// A graph that calls a function containing only a plain Identity node, where
// none of the nodes built here set any XLA attribute, must be reported as
// unable to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = identity_func;

  // Top-level graph: a placeholder fed into a PartitionedCall of IdentityFunc.
  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("IdentityFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
}
|
||||
|
||||
// A function body containing an "XlaSort" node — an op name listed among the
// ops that trigger XLA compilation — must make the graph report `true`, even
// though no XLA attribute is set on any node built here.
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef sort_func = FunctionDefHelper::Create(
      "SortFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = sort_func;

  // Top-level graph: a placeholder fed into a PartitionedCall of SortFunc.
  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("SortFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
|
||||
|
||||
// A library function whose PartitionedCall node carries kXlaMustCompileAttr
// set to `true` must make the graph report that XLA compilation can be
// triggered.
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  AttrValue true_attribute;
  true_attribute.set_b(true);

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  // Function-level attribute. CanTriggerXlaCompilation only looks at node
  // attributes, so the expectation below relies on the node-level attribute
  // set on "func_call" in CallIdentity.
  (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;

  // The callee is "IdentityFunc", the function defined above; the previous
  // name "IdentityRef" did not match any function added to the library.
  FunctionDef call_identity = FunctionDefHelper::Create(
      "CallIdentity",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"z:float"}, /*attr_def=*/{},
      /*node_def=*/
      {{{"func_call"},
        "PartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice({DT_FLOAT})},
         {"Tout", DataTypeSlice({DT_FLOAT})},
         {"f",
          FunctionDefHelper::FunctionRef("IdentityFunc", {{"T", DT_FLOAT}})},
         {kXlaMustCompileAttr, true}}}},
      /*ret_def=*/{{"z", "func_call:output:0"}});

  *library.add_function() = identity_func;
  *library.add_function() = call_identity;

  // Top-level graph: a placeholder fed into a PartitionedCall of CallIdentity.
  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("CallIdentity");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope";
|
||||
// only when auto_jit is ON.
|
||||
const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
|
||||
|
||||
// Name of the node attribute holding the id of the compiled cluster.
const char* const kXlaClusterIdAttr = "_xla_compile_id";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
||||
|
||||
// The id of the compiled cluster.
|
||||
extern const char* const kXlaClusterIdAttr; // "_xla_compile_id"
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -34,9 +35,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
|
||||
"_xla_compile_id";
|
||||
|
||||
namespace {
|
||||
|
||||
const char* const kXlaClusterOutput = "XlaClusterOutput";
|
||||
@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) {
|
||||
for (Node* n : graph->nodes()) {
|
||||
string name;
|
||||
// Only consider nodes being compiled.
|
||||
if (!GetNodeAttr(n->attrs(),
|
||||
EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
|
||||
.ok())
|
||||
continue;
|
||||
if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
|
||||
// Early return for any node with a device that is not a CPU or GPU.
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
|
||||
@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
retvals[i]->AddAttr("index", i);
|
||||
}
|
||||
|
||||
AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
|
||||
call_def);
|
||||
AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def);
|
||||
AddNodeAttr("_variable_start_index", variable_start_index, call_def);
|
||||
|
||||
// Uniquify the function name.
|
||||
@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
// O(n) pass over the edges.
|
||||
for (const Edge* e : (*graph)->edges()) {
|
||||
if (!e->IsControlEdge() &&
|
||||
e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
|
||||
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
|
||||
e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr &&
|
||||
e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr &&
|
||||
e->dst()->type_string() != kXlaClusterOutput) {
|
||||
return errors::InvalidArgument(
|
||||
"Undeclared output of XLA computation. Some common causes of this "
|
||||
@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
|
||||
auto output = absl::make_unique<Graph>((*graph)->op_registry());
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true,
|
||||
&output, flib_def),
|
||||
EncapsulateSubgraphsInFunctions(
|
||||
kXlaClusterIdAttr, **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true, &output, flib_def),
|
||||
"EncapsulateXlaComputationsPass failed");
|
||||
graph->swap(output);
|
||||
return Status::OK();
|
||||
@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
// while iterating.
|
||||
std::vector<Node*> launch_nodes;
|
||||
for (Node* n : graph->nodes()) {
|
||||
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
|
||||
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr);
|
||||
if (!name.empty()) {
|
||||
launch_nodes.push_back(n);
|
||||
}
|
||||
|
@ -34,8 +34,6 @@ namespace tensorflow {
|
||||
// XlaLaunch operators.
|
||||
class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
static const char* const kXlaClusterAttr; // _xla_compile_id
|
||||
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
// The following methods are public only for unit tests.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/test_util.h"
|
||||
@ -46,19 +47,18 @@ static std::unique_ptr<Graph> MakeOuterGraph(
|
||||
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
|
||||
|
||||
NodeDef def;
|
||||
TF_CHECK_OK(
|
||||
NodeDefBuilder("launch0", function, &flib_def)
|
||||
.Input(a.node()->name(), 0, DT_INT32)
|
||||
.Input(b.node()->name(), 0, DT_FLOAT)
|
||||
.Input(c.node()->name(), 0, DT_INT32)
|
||||
.Input(d.node()->name(), 0, DT_FLOAT)
|
||||
.Input(u.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(v.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(w.node()->name(), 0, DT_RESOURCE)
|
||||
.Device("/gpu:0")
|
||||
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
|
||||
.Attr("_variable_start_index", 4)
|
||||
.Finalize(&def));
|
||||
TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
|
||||
.Input(a.node()->name(), 0, DT_INT32)
|
||||
.Input(b.node()->name(), 0, DT_FLOAT)
|
||||
.Input(c.node()->name(), 0, DT_INT32)
|
||||
.Input(d.node()->name(), 0, DT_FLOAT)
|
||||
.Input(u.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(v.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(w.node()->name(), 0, DT_RESOURCE)
|
||||
.Device("/gpu:0")
|
||||
.Attr(kXlaClusterIdAttr, "launch0")
|
||||
.Attr("_variable_start_index", 4)
|
||||
.Finalize(&def));
|
||||
|
||||
Status status;
|
||||
Node* launch = scope.graph()->AddNode(def, &status);
|
||||
@ -107,7 +107,7 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
|
||||
auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
node->set_requested_device("/gpu:0");
|
||||
};
|
||||
|
||||
@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
|
||||
: ops::Add(scope.WithOpName("E"), a1, a0);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
|
||||
"launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
};
|
||||
add_attrs(e.node());
|
||||
|
||||
@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
|
||||
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
node->set_requested_device("/gpu:0");
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user