diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 20efbe248d7..62e121420c3 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, return Status::OK(); } +xla::StatusOr> MakeCallNodesFromAttribute( + const Node& node, absl::string_view attr_name, + absl::string_view call_name) { + std::vector attr_lists; + TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists)); + + std::vector out; + for (int i = 0; i < attr_lists.size(); i++) { + out.emplace_back(); + NodeDef& inserted = out.back(); + inserted.set_name(absl::StrCat(call_name, "_", i)); + inserted.set_op(attr_lists[i].name()); + *inserted.mutable_attr() = attr_lists[i].attr(); + } + return out; +} + // Utility which searches for values in a sorted list by scanning over it once. // No matter how many times ScanForValue is called, the list is scanned at most // once. However, if a call to ScanForValue skips over a value, that value is @@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf( return is_compilable; } +bool RecursiveCompilabilityChecker::IsCompilableCase( + const Node& case_node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { + xla::StatusOr> calls = + MakeCallNodesFromAttribute(case_node, "branches", "branch"); + if (!calls.ok()) { + VLOG(2) << "Rejecting node " << case_node.name() << ": " + << "missing attribute 'branches'"; + return false; + } + + bool is_compilable = true; + + for (const NodeDef& call : *calls) { + is_compilable &= + IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes); + } + return is_compilable; +} + // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. @@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (op_filter_.require_always_compilable && node.IsCaseNode() && + !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { + LogNotCompilable(node, "unsupported case"); + return false; + } + if (!op_filter_.allow_stateful_rng_ops && IsStatefulRandomOp(node.type_string())) { absl::string_view uncompilable_reason = "stateful random op"; diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 3c1378bf764..65da072483b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker { // Whether ops known to have numerical accuracy issues should be considered // compilable.. bool allow_inaccurate_ops = false; + + // Require the function to be always compilable, regardless whether some + // control flow branches might be dead for a given input. + bool require_always_compilable = false; }; RecursiveCompilabilityChecker(OperationFilter op_filter, @@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker { NameAttrList* encapsulating_function, UncompilableNodesMap* uncompilable_nodes) const; + // Tests whether 'case_node' is compilable. Every operator in all branches + // must be compilable. + bool IsCompilableCase(const Node& case_node, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + // Returns compilability of node def retrieved from `node`'s attribute with // name `attr_name`. bool ExtractNodeDefAndCheckCompilability( diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 3851c66ba1a..9058b129589 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -34,7 +34,16 @@ limitations under the License. namespace tensorflow { namespace { +AttrValue FuncListAttr(const absl::Span names) { + AttrValue attr; + for (const char* name : names) { + attr.mutable_list()->add_func()->set_name(name); + } + return attr; +} + constexpr char kFunctionalIfNodeName[] = "If"; +constexpr char kFunctionalCaseNodeName[] = "Case"; constexpr char kFunctionalWhileNodeName[] = "While"; constexpr char kCompilableFunctionName[] = "CompilableFn"; constexpr char kCompilableFunctionNodeName[] = "n_c"; @@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test { op_filter_.allow_inaccurate_ops = false; op_filter_.allow_slow_ops = false; - checker_ = absl::make_unique(op_filter_, - device_type_); + checker_ = CreateCompilabilityChecker(); + } + + std::unique_ptr CreateCompilabilityChecker() { + return absl::make_unique(op_filter_, + device_type_); } FunctionLibraryRuntime* GetFunctionLibraryRuntime() { @@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { "unsupported op")); } +TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) { + FunctionDefLibrary flib; + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_c_uncompilable:float"}, + /*Attributes*/ {}, + // Node info + {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}}); + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionTwoName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_d_uncompilable:float"}, + /*Attribute*/ {}, + // Node info + {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}}); + + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib)); + auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32); + auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32); + std::vector inputes( + {NodeBuilder::NodeOut(placeholder.node())}); + Node* case_node; + TF_ASSERT_OK( + NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputes) + .Attr("branches", FuncListAttr({kUncompilableFunctionName, + kUncompilableFunctionTwoName})) + .Attr("Tout", {DT_INT32}) + .Finalize(root.graph(), &case_node)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + auto case_node_it = std::find_if( + graph->nodes().begin(), graph->nodes().end(), + [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; }); + EXPECT_NE(case_node_it, graph->nodes().end()); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + op_filter_.require_always_compilable = false; + checker_ = CreateCompilabilityChecker(); + EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); + op_filter_.require_always_compilable = true; + checker_ = CreateCompilabilityChecker(); + EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); +} + TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { GraphDefBuilder b(GraphDefBuilder::kFailImmediately); Scope root = Scope::NewRootScope().ExitOnError(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 41f80203353..317e29d4a84 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - if (!RecursiveCompilabilityChecker{ - CreateOperationFilter(*registration), - DeviceType{registration->compilation_device_name}} - .IsCompilableNode(*node, lib_runtime)) { + RecursiveCompilabilityChecker::OperationFilter filter = + CreateOperationFilter(*registration); + filter.require_always_compilable = true; + + RecursiveCompilabilityChecker checker( + filter, DeviceType{registration->compilation_device_name}); + + if (!checker.IsCompilableNode(*node, lib_runtime)) { continue; } diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 1299ef27758..c1e6433c692 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -824,7 +824,6 @@ cuda_py_test( tags = [ "no_oss", ], - xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -843,7 +842,6 @@ cuda_py_test( "no_cuda11", "no_oss", ], - xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras",